Skip to content

Commit

Permalink
Merge pull request #305 from nipreps/docker/t1-t2-derivatives
Browse files Browse the repository at this point in the history
ENH+RF: Allow precomputed derivatives in T1w or T2w space
  • Loading branch information
mgxd authored Aug 30, 2023
2 parents 45d48ef + b6838ef commit 131f0d5
Show file tree
Hide file tree
Showing 10 changed files with 591 additions and 254 deletions.
185 changes: 122 additions & 63 deletions nibabies/utils/bids.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,134 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""Utilities to handle BIDS inputs."""
from __future__ import annotations

import json
import os
import sys
import typing as ty
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import IO, List, Literal, Optional, Union

import nibabel as nb
import numpy as np
from bids.layout import BIDSLayout, Query

# Default derivative queries, keyed by the name under which a match is stored.
# Each value holds the entity filters forwarded to the layout query.
_spec: dict = {
    # Brain mask in T1w space
    't1w_mask': {
        'datatype': 'anat',
        'desc': 'brain',
        'space': 'T1w',
        'suffix': 'mask',
    },
    # Segmentation (aseg) in T1w space
    't1w_aseg': {
        'datatype': 'anat',
        'desc': 'aseg',
        'space': 'T1w',
        'suffix': 'dseg',
    },
    # Brain mask in T2w space
    't2w_mask': {
        'datatype': 'anat',
        'desc': 'brain',
        'space': 'T2w',
        'suffix': 'mask',
    },
    # Segmentation (aseg) in T2w space
    't2w_aseg': {
        'datatype': 'anat',
        'desc': 'aseg',
        'space': 'T2w',
        'suffix': 'dseg',
    },
}


class Derivatives:
"""
A container class for collecting and storing derivatives.
A specification (either dictionary or JSON file) can be used to customize derivatives and
queries.
To populate this class with derivatives, the `populate()` method must first be called.
"""

@dataclass
class BOLDGrouping:
"""This class is used to facilitate the grouping of BOLD series."""
def __getattribute__(self, attr):
    """Look up `attr` normally, yielding ``None`` instead of raising when it is absent."""
    try:
        value = object.__getattribute__(self, attr)
    except AttributeError:
        # The spec may not define every name callers ask for; treat those as unset.
        value = None
    return value

def __init__(self, bids_root: Path | str, spec: dict | Path | str | None = None, **args):
    """
    Set up an empty container for the derivatives named by the spec.

    Parameters
    ----------
    bids_root
        Stored as a ``Path`` in ``self.bids_root``.
    spec
        Mapping of derivative name to query, or a path to a JSON file holding such a
        mapping. When omitted, the module-level ``_spec`` is used.
    **args
        Accepted but not used.
    """
    self.bids_root = Path(bids_root)

    if spec is None:
        self.spec = _spec
    elif isinstance(spec, dict):
        self.spec = spec
    else:
        # Anything that is not a dict is treated as a path to a JSON spec file
        self.spec = json.loads(Path(spec).read_text())

    self.names = set(self.spec.keys())
    self.references = dict.fromkeys(self.names)
    # Every spec name starts out unset, both as attribute and as reference
    for entry in self.names:
        setattr(self, entry, None)

def __repr__(self):
    """List, one per line, the names whose stored value is truthy."""
    populated = (entry for entry in self.names if getattr(self, entry))
    return '\n'.join(populated)

def __contains__(self, val: str):
    """Report whether `val` is one of the names in ``self.names``."""
    known_names = self.names
    return val in known_names

def __bool__(self):
    """Return ``True`` as soon as any name in ``self.names`` holds a truthy value."""
    for entry in self.names:
        if getattr(self, entry):
            return True
    return False

def populate(
    self, deriv_path, subject_id: str, session_id: str | Query | None = Query.OPTIONAL
) -> None:
    """
    Query a derivatives directory and populate values and references based on the spec.

    For each entry of the spec, a derivative is stored only when exactly one NIfTI file
    matches the query, its metadata carries a single ``SpatialReference``, and
    ``validate()`` accepts the derivative against that reference. Any other outcome
    emits a warning and leaves the entry unset.

    Parameters
    ----------
    deriv_path
        Directory handed to ``BIDSLayout`` (indexed with ``validate=False``).
    subject_id
        Subject entity used in the query.
    session_id
        Session entity used in the query.
    """
    layout = BIDSLayout(deriv_path, validate=False)
    for name, query in self.spec.items():
        items = layout.get(
            subject=subject_id,
            session=session_id,
            extension=['.nii', '.nii.gz'],
            **query,
        )
        if not items:
            warnings.warn(f"Could not find {name}")
            continue
        if len(items) > 1:
            # Ambiguous match: report what was found rather than claiming nothing was
            warnings.warn(
                f"Found multiple matches for {name}: {[found.path for found in items]}"
            )
            continue
        item = items[0]

        # Skip if derivative does not have valid metadata
        metadata = item.get_metadata()
        if not metadata:
            warnings.warn(f"No metadata found for {item}")
            continue
        reference = metadata.get('SpatialReference')
        if not reference:
            warnings.warn(f"No SpatialReference found in metadata for {item}")
            continue
        if isinstance(reference, list):
            if len(reference) > 1:
                warnings.warn(f"Multiple reference found: {reference}")
                continue
            reference = reference[0]

        # The reference is resolved relative to the BIDS root
        reference = self.bids_root / reference
        if not self.validate(item.path, str(reference)):
            warnings.warn(f"Validation failed between: {item.path} and {reference}")
            continue

        setattr(self, name, Path(item.path))
        self.references[name] = reference

session: Union[str, None]
pe_dir: str
readout: float
multiecho_id: str = None
files: List[IO] = field(default_factory=list)
@property
def mask(self) -> str | None:
    """Return ``t1w_mask`` when it is truthy, otherwise ``t2w_mask``."""
    preferred = self.t1w_mask
    return preferred if preferred else self.t2w_mask

@property
def name(self) -> str:
    """Return the session, pe_dir, readout and multiecho_id attributes joined by hyphens."""
    label = f"{self.session}-{self.pe_dir}-{self.readout}-{self.multiecho_id}"
    return label
def aseg(self) -> str | None:
    """Return ``t1w_aseg`` when it is truthy, otherwise ``t2w_aseg``."""
    preferred = self.t1w_aseg
    return preferred if preferred else self.t2w_aseg

def add_file(self, fl) -> None:
    """Append `fl` to ``self.files``."""
    collected = self.files
    collected.append(fl)
@staticmethod
def validate(derivative: str, reference: str, atol: float = 1e-5) -> bool:
    """
    Check that a derivative image lies on the same grid as its reference image.

    Returns ``True`` only when both images share axis orientation codes, have the same
    shape, and their affines agree within the absolute tolerance `atol`.
    """
    ref_img = nb.load(reference)
    ref_axcodes = nb.aff2axcodes(ref_img.affine)
    deriv_img = nb.load(derivative)

    if nb.aff2axcodes(deriv_img.affine) != ref_axcodes:
        return False
    if deriv_img.shape != ref_img.shape:
        return False
    if not np.allclose(ref_img.affine, deriv_img.affine, atol=atol):
        return False
    return True


def write_bidsignore(deriv_dir):
Expand Down Expand Up @@ -221,55 +324,11 @@ def validate_input_dir(exec_env, bids_dir, participant_label):
print("bids-validator does not appear to be installed", file=sys.stderr)


def collect_precomputed_derivatives(layout, subject_id, derivatives_filters=None):
    """
    Query and collect precomputed derivatives.

    This function is used to determine which workflow steps can be skipped,
    based on the files found.

    Parameters
    ----------
    layout
        Object whose ``get()`` method is queried with ``scope='derivatives'`` and
        ``return_type="filename"`` (presumably a ``BIDSLayout`` - confirm with callers).
    subject_id
        Subject entity used in every query.
    derivatives_filters
        Optional mapping of derivative name to query; entries override or extend the
        default ``anat_mask`` / ``anat_aseg`` queries.

    Returns
    -------
    dict
        Derivative name mapped to the single matching filename. Names without a match
        are omitted.

    Raises
    ------
    Exception
        If a query matches more than one file.
    """

    deriv_queries = {
        'anat_mask': {
            'datatype': 'anat',
            'desc': 'brain',
            'space': 'orig',
            'suffix': 'mask',
        },
        'anat_aseg': {
            'datatype': 'anat',
            'desc': 'aseg',
            'space': 'orig',
            'suffix': 'dseg',
        },
    }
    if derivatives_filters is not None:
        deriv_queries.update(derivatives_filters)

    derivatives = {}
    for deriv, query in deriv_queries.items():
        res = layout.get(
            scope='derivatives',
            subject=subject_id,
            extension=['.nii', '.nii.gz'],
            return_type="filename",
            **query,
        )
        if not res:
            continue
        if len(res) > 1:  # Some queries may want multiple results
            # `res` already holds filenames (return_type="filename"), so report it directly;
            # accessing `.path` on these entries would fail before the message is built.
            raise Exception(
                f"When searching for <{deriv}>, found multiple results: {list(res)}"
            )
        derivatives[deriv] = res[0]
    return derivatives


def parse_bids_for_age_months(
bids_root: Union[str, Path],
bids_root: str | Path,
subject_id: str,
session_id: Optional[str] = None,
) -> Optional[int]:
session_id: str | None = None,
) -> int | None:
"""
Given a BIDS root, query the BIDS metadata files for participant age, in months.
Expand All @@ -295,8 +354,8 @@ def parse_bids_for_age_months(


def _get_age_from_tsv(
bids_tsv: Path, level: Literal['session', 'participant'], key: str
) -> Optional[int]:
bids_tsv: Path, level: ty.Literal['session', 'participant'], key: str
) -> int | None:
import pandas as pd

df = pd.read_csv(str(bids_tsv), sep='\t')
Expand Down
Empty file.
113 changes: 113 additions & 0 deletions nibabies/utils/tests/test_bids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import annotations

import json
import typing as ty
from pathlib import Path

import pytest

from nibabies.utils import bids


def _create_nifti(filename: str) -> str:
    """Write a 4x4x4 all-zero int8 NIfTI image with an identity affine; return `filename`."""
    import nibabel as nb
    import numpy as np

    voxels = np.zeros((4, 4, 4), dtype='int8')
    affine = np.eye(4)
    image = nb.Nifti1Image(voxels, affine)
    image.to_filename(filename)
    return filename


def _create_bids_dir(root_path: Path):
    """Create ``sub-01/anat`` under `root_path` and place T1w and T2w images in it."""
    if not root_path.exists():
        root_path.mkdir()
    anat_dir = root_path / 'sub-01' / 'anat'
    anat_dir.mkdir(parents=True)
    for image_name in ('sub-01_T1w.nii.gz', 'sub-01_T2w.nii.gz'):
        _create_nifti(str(anat_dir / image_name))


def _create_bids_derivs(
    root_path: Path,
    *,
    t1w_mask: bool = False,
    t1w_aseg: bool = False,
    t2w_mask: bool = False,
    t2w_aseg: bool = False,
):
    """
    Create a derivatives dataset under `root_path` holding the requested files.

    A ``dataset_description.json`` and the ``sub-01/anat`` directory are always created.
    Each enabled flag adds one NIfTI image plus a JSON sidecar whose ``SpatialReference``
    names the T1w or T2w image the derivative belongs to.
    """
    if not root_path.exists():
        root_path.mkdir()
    description = {'Name': 'Derivatives Test', 'BIDSVersion': '1.8.0', 'DatasetType': 'derivative'}
    (root_path / 'dataset_description.json').write_text(json.dumps(description))
    anat_dir = root_path / 'sub-01' / 'anat'
    anat_dir.mkdir(parents=True)

    t1w_reference = 'sub-01/anat/sub-01_T1w.nii.gz'
    t2w_reference = 'sub-01/anat/sub-01_T2w.nii.gz'
    # (requested?, file stem, spatial reference), in creation order
    candidates = (
        (t1w_mask, 'sub-01_space-T1w_desc-brain_mask', t1w_reference),
        (t1w_aseg, 'sub-01_space-T1w_desc-aseg_dseg', t1w_reference),
        (t2w_mask, 'sub-01_space-T2w_desc-brain_mask', t2w_reference),
        (t2w_aseg, 'sub-01_space-T2w_desc-aseg_dseg', t2w_reference),
    )
    for requested, stem, reference in candidates:
        if not requested:
            continue
        _create_nifti(str((anat_dir / stem).with_suffix('.nii.gz')))
        sidecar = (anat_dir / stem).with_suffix('.json')
        sidecar.write_text(json.dumps({'SpatialReference': reference}))


@pytest.mark.parametrize(
    't1w_mask,t1w_aseg,t2w_mask,t2w_aseg,mask,aseg',
    [
        (True, True, False, False, 't1w_mask', 't1w_aseg'),
        (True, True, True, True, 't1w_mask', 't1w_aseg'),
        (False, False, True, True, 't2w_mask', 't2w_aseg'),
        (True, False, False, True, 't1w_mask', 't2w_aseg'),
        (False, False, False, False, None, None),
    ],
)
def test_derivatives(
    tmp_path: Path,
    t1w_mask: bool,
    t1w_aseg: bool,
    t2w_mask: bool,
    t2w_aseg: bool,
    mask: str | None,
    aseg: str | None,
):
    """
    Check which derivative `Derivatives.mask` / `.aseg` resolve to after `populate()`.

    The first four parameters select which derivative files exist on disk; `mask` and
    `aseg` name the attribute each property is expected to match (``None`` when no
    derivative of that kind exists).
    """
    bids_dir = tmp_path / 'bids'
    _create_bids_dir(bids_dir)
    deriv_dir = tmp_path / 'derivatives'
    _create_bids_derivs(
        deriv_dir, t1w_mask=t1w_mask, t1w_aseg=t1w_aseg, t2w_mask=t2w_mask, t2w_aseg=t2w_aseg
    )

    # Nothing is set until populate() has been called
    derivatives = bids.Derivatives(bids_dir)
    assert derivatives.mask is None
    assert derivatives.t1w_mask is None
    assert derivatives.t2w_mask is None
    assert derivatives.aseg is None
    assert derivatives.t1w_aseg is None
    assert derivatives.t2w_aseg is None

    derivatives.populate(deriv_dir, subject_id='01')
    if mask:
        assert derivatives.mask == getattr(derivatives, mask)
        assert derivatives.references[mask]
    else:
        assert derivatives.mask is None
    if aseg:
        assert derivatives.aseg == getattr(derivatives, aseg)
        assert derivatives.references[aseg]
    else:
        # Identity check, consistent with the mask branch above
        assert derivatives.aseg is None
53 changes: 0 additions & 53 deletions nibabies/utils/validation.py

This file was deleted.

Loading

0 comments on commit 131f0d5

Please sign in to comment.