From d221f03a19e3b9d897691957e97bcb4ced86e71b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Wed, 12 Jun 2024 17:53:19 +0200 Subject: [PATCH] ENH: Make binning strategy more robust Prefer capping b-values to a sensible maximum default value to make the binning strategy more robust. Raise a `ValueError` users when no actual shells are found. Return the median estimated b-vals for each bin. Co-authored-by: Oscar Esteban --- src/eddymotion/model/dmri_utils.py | 24 ++++++++++---- src/eddymotion/model/tests/test_dmri_utils.py | 32 +++++++++++++++---- 2 files changed, 43 insertions(+), 13 deletions(-) diff --git a/src/eddymotion/model/dmri_utils.py b/src/eddymotion/model/dmri_utils.py index a0774f86..079bbcc5 100644 --- a/src/eddymotion/model/dmri_utils.py +++ b/src/eddymotion/model/dmri_utils.py @@ -28,11 +28,15 @@ DEFAULT_MULTISHELL_BIN_COUNT_THR = 7 """Default bin count to consider a multishell scheme.""" +DEFAULT_MAX_BVAL = 8000 +"""Maximum b-value cap.""" + def find_shelling_scheme( bvals, num_bins=DEFAULT_NUM_BINS, multishell_nonempty_bin_count_thr=DEFAULT_MULTISHELL_BIN_COUNT_THR, + bval_cap=DEFAULT_MAX_BVAL, ): """ Find the shelling scheme on the given b-values. @@ -57,23 +61,29 @@ def find_shelling_scheme( Shelling scheme. bval_groups : :obj:`list` List of grouped b-values. + bval_estimated : :obj:`list` + List of 'estimated' b-values as the median value of each b-value group. + """ # Bin the b-values: use -1 as the lower bound to be able to appropriately # include b0 values - bins = np.linspace(-1, max(bvals), num_bins + 1) - hist, bin_edges = np.histogram(bvals, bins=bins) + hist, bin_edges = np.histogram(bvals, bins=num_bins, range=(-1, min(max(bvals), bval_cap))) # Collect values in each bin bval_groups = [] + bval_estimated = [] for lower, upper in zip(bin_edges[:-1], bin_edges[1:], strict=False): - bval_groups.append(bvals[(bvals > lower) & (bvals <= upper)]) - - # Remove empty bins from the list - bval_groups = [v for v in bval_groups if len(v)] + # Add only if a nonempty b-values mask + if (mask := (bvals > lower) & (bvals <= upper)).sum(): + bval_groups.append(bvals[mask]) + bval_estimated.append(np.median(bvals[mask])) nonempty_bins = len(bval_groups) + if nonempty_bins < 2: + raise ValueError("DWI must have at least one high-b shell") + if nonempty_bins == 2: scheme = "single-shell" elif nonempty_bins < multishell_nonempty_bin_count_thr: @@ -81,4 +91,4 @@ def find_shelling_scheme( else: scheme = "DSI" - return scheme, bval_groups + return scheme, bval_groups, bval_estimated diff --git a/src/eddymotion/model/tests/test_dmri_utils.py b/src/eddymotion/model/tests/test_dmri_utils.py index eada8b70..c6c475e7 100644 --- a/src/eddymotion/model/tests/test_dmri_utils.py +++ b/src/eddymotion/model/tests/test_dmri_utils.py @@ -33,7 +33,7 @@ @pytest.mark.parametrize( - ("bvals", "exp_scheme", "exp_bval_groups"), + ("bvals", "exp_scheme", "exp_bval_groups", "exp_bval_estimated"), [ ( np.asarray( @@ -175,20 +175,22 @@ ] ), ], + [5, 300, 1000, 2000], ), ], ) -def test_find_shelling_scheme_array(bvals, exp_scheme, exp_bval_groups): - obt_scheme, obt_bval_groups = find_shelling_scheme(bvals) +def test_find_shelling_scheme_array(bvals, exp_scheme, exp_bval_groups, exp_bval_estimated): + obt_scheme, obt_bval_groups, obt_bval_estimated = find_shelling_scheme(bvals) assert obt_scheme == exp_scheme assert all( np.allclose(obt_arr, exp_arr) for obt_arr, exp_arr in zip(obt_bval_groups, exp_bval_groups, strict=True) ) + assert np.allclose(obt_bval_estimated, exp_bval_estimated) @pytest.mark.parametrize( - ("dwi_btable", "exp_scheme", "exp_bval_groups"), + ("dwi_btable", "exp_scheme", "exp_bval_groups", "exp_bval_estimated"), [ ( "ds000114_singleshell", @@ -264,6 +266,7 @@ def test_find_shelling_scheme_array(bvals, exp_scheme, exp_bval_groups): ] ), ], + [0.0, 1000.0], ), ( "hcph_multishell", @@ -546,6 +549,7 @@ def test_find_shelling_scheme_array(bvals, exp_scheme, exp_bval_groups): ] ), ], + [0.0, 700.0, 1000.0, 2000.0, 3000.0], ), ( "ds004737_dsi", @@ -606,15 +610,31 @@ def test_find_shelling_scheme_array(bvals, exp_scheme, exp_bval_groups): ] ), ], + [ + 5.0, + 995.0, + 1195.0, + 1595.0, + 1797.5, + 2190.0, + 2595.0, + 2795.0, + 3400.0, + 3790.0, + 4195.0, + 4390.0, + 4990.0, + ], ), ], ) -def test_find_shelling_scheme_files(dwi_btable, exp_scheme, exp_bval_groups): +def test_find_shelling_scheme_files(dwi_btable, exp_scheme, exp_bval_groups, exp_bval_estimated): bvals = np.loadtxt(_datadir / f"{dwi_btable}.bval") - obt_scheme, obt_bval_groups = find_shelling_scheme(bvals) + obt_scheme, obt_bval_groups, obt_bval_estimated = find_shelling_scheme(bvals) assert obt_scheme == exp_scheme assert all( np.allclose(obt_arr, exp_arr) for obt_arr, exp_arr in zip(obt_bval_groups, exp_bval_groups, strict=True) ) + assert np.allclose(obt_bval_estimated, exp_bval_estimated)