Skip to content

Commit

Permalink
Merge pull request #395 from mgxd/fix/age-parsing
Browse files: browse the repository at this point in the history
FIX/ENH: Improvements to age parsing
  • Loading branch information
mgxd authored Sep 20, 2024
2 parents 9b149e4 + 6339bd3 commit d24cd43
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 57 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ repos:
- id: check-toml
- id: check-added-large-files
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.0
rev: v0.6.5
hooks:
- id: ruff
args: [ --fix ]
- id: ruff-format
24 changes: 0 additions & 24 deletions nibabies/_warnings.py

This file was deleted.

48 changes: 33 additions & 15 deletions nibabies/utils/bids.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
import os
import sys
import typing as ty
import warnings
from pathlib import Path

import pandas as pd

SUPPORTED_AGE_UNITS = (
'weeks',
'months',
Expand Down Expand Up @@ -244,7 +247,11 @@ def parse_bids_for_age_months(

scans_tsv = session_level / f'{prefix}_scans.tsv'
if scans_tsv.exists():
age = _get_age_from_tsv(scans_tsv)
age = _get_age_from_tsv(
scans_tsv,
index_column='filename',
index_val=r'^anat.*',
)

if age is not None:
return age
Expand All @@ -258,16 +265,18 @@ def parse_bids_for_age_months(

participants_tsv = Path(bids_root) / 'participants.tsv'
if participants_tsv.exists() and age is None:
age = _get_age_from_tsv(participants_tsv, index_column='participant_id', index_val=subject)
age = _get_age_from_tsv(
participants_tsv, index_column='participant_id', index_value=subject
)

return age


def _get_age_from_tsv(
bids_tsv: Path, index_column: str | None = None, index_val: str | None = None
) -> int | None:
import pandas as pd

bids_tsv: Path,
index_column: str | None = None,
index_value: str | None = None,
) -> float | None:
df = pd.read_csv(str(bids_tsv), sep='\t')
age_col = None

Expand All @@ -278,14 +287,18 @@ def _get_age_from_tsv(
if age_col is None:
return

if not index_column or not index_val: # Just grab first value
idx = df.index[0]
else:
idx = df.index[df[index_column] == index_val].item()
df = df[df[index_column].str.fullmatch(index_value)]

# Multiple indices may be present after matching
if len(df) > 1:
warnings.warn(
f'Multiple matches for {index_column}:{index_value} found in {bids_tsv.name}.',
stacklevel=1,
)

try:
# extract age value from row
age = int(df.loc[idx, age_col].item())
age = float(df.loc[df.index[0], age_col].item())
except Exception: # noqa: BLE001
return

Expand All @@ -294,7 +307,10 @@ def _get_age_from_tsv(
bids_json = bids_tsv.with_suffix('.json')
age_units = _get_age_units(bids_json)
if age_units is False:
return None
raise FileNotFoundError(
f'Could not verify age unit for {bids_tsv.name} - ensure a sidecar JSON '
'describing column `age` units is available.'
)
else:
age_units = age_col.split('_')[-1]

Expand All @@ -318,12 +334,14 @@ def _get_age_units(bids_json: Path) -> ty.Literal['weeks', 'months', 'years', Fa
return False


def age_to_months(age: int, units: ty.Literal['weeks', 'months', 'years']) -> int:
def age_to_months(age: int | float, units: ty.Literal['weeks', 'months', 'years']) -> int:
"""
Convert a given age, in either "weeks", "months", or "years", into months.
>>> age_to_months(1, 'years')
12
>>> age_to_months(0.5, 'years')
6
>>> age_to_months(2, 'weeks')
0
>>> age_to_months(3, 'weeks')
Expand All @@ -335,7 +353,7 @@ def age_to_months(age: int, units: ty.Literal['weeks', 'months', 'years']) -> in
YEARS_TO_MONTH = 12

if units == 'weeks':
age = round(age * WEEKS_TO_MONTH)
age *= WEEKS_TO_MONTH
elif units == 'years':
age *= YEARS_TO_MONTH
return age
return int(round(age))
63 changes: 46 additions & 17 deletions nibabies/utils/tests/test_bids.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,30 +24,59 @@ def create_sidecar(tsv_file: Path, units) -> None:
# Column fixtures merged into the TSV tables written by the tests below.
# Each value list has three entries, one per table row.
age_months = {
    'age_months': [3, 6, 9],
}
age_years = {
    'age_years': [1, 1, 2],
}

# Index columns. Note that 'sub-1' is a strict prefix of 'sub-11'.
participants = {
    'participant_id': ['sub-1', 'sub-2', 'sub-11'],
}
sessions = {
    'session_id': ['ses-1', 'ses-2', 'ses-3'],
}
# Only the second filename sits under the `anat/` folder.
scans = {
    'filename': [
        'dwi/sub-01_dwi.nii.gz',
        'anat/sub-01_T1w.nii.gz',
        'func/sub-01_task-rest_bold.nii.gz',
    ],
}


@pytest.mark.parametrize(
('idx_col', 'idx_val', 'data', 'sidecar', 'expected'),
('idx_col', 'idx_val', 'data', 'units', 'expected'),
[
('session_id', 'x1', age, False, None),
('session_id', 'x1', age, 'months', 4),
('session_id', 'x1', age, 'weeks', 1), # Convert from 4 weeks -> 1 month
('session_id', 'x1', age, ['months', 'weeks'], None),
('session_id', 'x2', age_weeks, False, 2),
('participant_id', 'x1', age_months, False, 3),
('participant_id', 'x3', age_years, False, 24),
('session_id', 'x3', {**age_months, **age}, False, 9),
(None, None, age_months, False, 3),
('session_id', 'ses-1', age, 'months', 4),
('session_id', 'ses-1', age, 'weeks', 1), # Convert from 4 weeks -> 1 month
('session_id', 'ses-2', age_weeks, False, 2),
('participant_id', 'sub-1', age_months, False, 3),
('participant_id', 'sub-11', age_years, False, 24),
('session_id', 'ses-3', {**age_months, **age}, False, 9),
('filename', r'^anat.*', age_months, False, 6),
],
)
def test_get_age_from_tsv(tmp_path, idx_col, idx_val, data, sidecar, expected):
def test_get_age_from_tsv(tmp_path, idx_col, idx_val, data, units, expected):
tsv_file = tmp_path / 'test-age-parsing.tsv'
base = {}
if idx_col is not None:
base[idx_col] = ['x1', 'x2', 'x3']
create_tsv({**base, **data}, tsv_file)

if sidecar:
create_sidecar(tsv_file, sidecar)
if idx_col == 'participant_id':
base = participants
elif idx_col == 'session_id':
base = sessions
elif idx_col == 'filename':
base = scans

create_tsv({**base, **data}, tsv_file)
if units:
create_sidecar(tsv_file, units)

res = _get_age_from_tsv(tsv_file, idx_col, idx_val)
assert res == expected


def test_get_age_from_tsv_error(tmp_path):
    """Expect ``FileNotFoundError`` when the TSV has no sidecar JSON.

    The table is written without calling ``create_sidecar``, so no JSON
    accompanies the TSV. ``age`` is defined elsewhere in this module;
    presumably it holds a unit-less ``age`` column - confirm.
    """
    tsv_path = tmp_path / 'participants.tsv'
    table = {**participants, **age}
    create_tsv(table, tsv_path)

    with pytest.raises(FileNotFoundError):
        _get_age_from_tsv(tsv_path, 'participant_id', 'sub-1')


def test_get_age_from_tsv_warning(tmp_path):
    """Expect a ``UserWarning`` when the index value matches several rows.

    ``sub-2`` appears twice in the ``participant_id`` column, so the lookup
    below matches more than one row of the table.
    """
    tsv_path = tmp_path / 'participants.tsv'
    duplicated_ids = {'participant_id': ['sub-1', 'sub-2', 'sub-2']}
    table = {**duplicated_ids, **age_months}
    create_tsv(table, tsv_path)

    with pytest.warns(UserWarning):
        _get_age_from_tsv(tsv_path, 'participant_id', 'sub-2')

0 comments on commit d24cd43

Please sign in to comment.