Skip to content

Commit

Permalink
ENH: Make binning strategy more robust
Browse files Browse the repository at this point in the history
Prefer capping b-values to a sensible maximum default value to make the binning strategy more robust.

Raise a `ValueError` users when no actual shells are found.

Return the median estimated b-vals for each bin.

Co-authored-by: Oscar Esteban <[email protected]>
  • Loading branch information
jhlegarreta and oesteban committed Jun 12, 2024
1 parent b41a427 commit d221f03
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 13 deletions.
24 changes: 17 additions & 7 deletions src/eddymotion/model/dmri_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@
DEFAULT_MULTISHELL_BIN_COUNT_THR = 7
"""Default bin count to consider a multishell scheme."""

DEFAULT_MAX_BVAL = 8000
"""Maximum b-value cap."""


def find_shelling_scheme(
bvals,
num_bins=DEFAULT_NUM_BINS,
multishell_nonempty_bin_count_thr=DEFAULT_MULTISHELL_BIN_COUNT_THR,
bval_cap=DEFAULT_MAX_BVAL,
):
"""
Find the shelling scheme on the given b-values.
Expand All @@ -57,28 +61,34 @@ def find_shelling_scheme(
Shelling scheme.
bval_groups : :obj:`list`
List of grouped b-values.
bval_estimated : :obj:`list`
List of 'estimated' b-values as the median value of each b-value group.
"""

# Bin the b-values: use -1 as the lower bound to be able to appropriately
# include b0 values
bins = np.linspace(-1, max(bvals), num_bins + 1)
hist, bin_edges = np.histogram(bvals, bins=bins)
hist, bin_edges = np.histogram(bvals, bins=num_bins, range=(-1, min(max(bvals), bval_cap)))

# Collect values in each bin
bval_groups = []
bval_estimated = []
for lower, upper in zip(bin_edges[:-1], bin_edges[1:], strict=False):
bval_groups.append(bvals[(bvals > lower) & (bvals <= upper)])

# Remove empty bins from the list
bval_groups = [v for v in bval_groups if len(v)]
# Add only if a nonempty b-values mask
if (mask := (bvals > lower) & (bvals <= upper)).sum():
bval_groups.append(bvals[mask])
bval_estimated.append(np.median(bvals[mask]))

nonempty_bins = len(bval_groups)

if nonempty_bins < 2:
raise ValueError("DWI must have at least one high-b shell")

if nonempty_bins == 2:
scheme = "single-shell"
elif nonempty_bins < multishell_nonempty_bin_count_thr:
scheme = "multi-shell"
else:
scheme = "DSI"

return scheme, bval_groups
return scheme, bval_groups, bval_estimated
32 changes: 26 additions & 6 deletions src/eddymotion/model/tests/test_dmri_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@


@pytest.mark.parametrize(
("bvals", "exp_scheme", "exp_bval_groups"),
("bvals", "exp_scheme", "exp_bval_groups", "exp_bval_estimated"),
[
(
np.asarray(
Expand Down Expand Up @@ -175,20 +175,22 @@
]
),
],
[5, 300, 1000, 2000],
),
],
)
def test_find_shelling_scheme_array(bvals, exp_scheme, exp_bval_groups):
obt_scheme, obt_bval_groups = find_shelling_scheme(bvals)
def test_find_shelling_scheme_array(bvals, exp_scheme, exp_bval_groups, exp_bval_estimated):
obt_scheme, obt_bval_groups, obt_bval_estimated = find_shelling_scheme(bvals)
assert obt_scheme == exp_scheme
assert all(
np.allclose(obt_arr, exp_arr)
for obt_arr, exp_arr in zip(obt_bval_groups, exp_bval_groups, strict=True)
)
assert np.allclose(obt_bval_estimated, exp_bval_estimated)


@pytest.mark.parametrize(
("dwi_btable", "exp_scheme", "exp_bval_groups"),
("dwi_btable", "exp_scheme", "exp_bval_groups", "exp_bval_estimated"),
[
(
"ds000114_singleshell",
Expand Down Expand Up @@ -264,6 +266,7 @@ def test_find_shelling_scheme_array(bvals, exp_scheme, exp_bval_groups):
]
),
],
[0.0, 1000.0],
),
(
"hcph_multishell",
Expand Down Expand Up @@ -546,6 +549,7 @@ def test_find_shelling_scheme_array(bvals, exp_scheme, exp_bval_groups):
]
),
],
[0.0, 700.0, 1000.0, 2000.0, 3000.0],
),
(
"ds004737_dsi",
Expand Down Expand Up @@ -606,15 +610,31 @@ def test_find_shelling_scheme_array(bvals, exp_scheme, exp_bval_groups):
]
),
],
[
5.0,
995.0,
1195.0,
1595.0,
1797.5,
2190.0,
2595.0,
2795.0,
3400.0,
3790.0,
4195.0,
4390.0,
4990.0,
],
),
],
)
def test_find_shelling_scheme_files(dwi_btable, exp_scheme, exp_bval_groups):
def test_find_shelling_scheme_files(dwi_btable, exp_scheme, exp_bval_groups, exp_bval_estimated):
bvals = np.loadtxt(_datadir / f"{dwi_btable}.bval")

obt_scheme, obt_bval_groups = find_shelling_scheme(bvals)
obt_scheme, obt_bval_groups, obt_bval_estimated = find_shelling_scheme(bvals)
assert obt_scheme == exp_scheme
assert all(
np.allclose(obt_arr, exp_arr)
for obt_arr, exp_arr in zip(obt_bval_groups, exp_bval_groups, strict=True)
)
assert np.allclose(obt_bval_estimated, exp_bval_estimated)

0 comments on commit d221f03

Please sign in to comment.