Skip to content

Commit

Permalink
get the transforms from the config file
Browse files Browse the repository at this point in the history
  • Loading branch information
Richard Lane committed Oct 21, 2024
1 parent 56632fc commit 9dcc6d2
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 22 deletions.
16 changes: 9 additions & 7 deletions fishjaw/model/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tqdm import tqdm

from ..images import io, transform
from ..util import files
from ..util import files, util


def get_patch_size(config: dict) -> tuple[int, int, int]:
Expand Down Expand Up @@ -234,22 +234,24 @@ def test_loader(
)


def _load_transform(transform: dict) -> tio.transforms.Transform:
def _load_transform(transform_name: str, args: dict) -> tio.transforms.Transform:
"""
Load a transform from the configuration, which should be provided as a dict of {"name": {"arg1": value1, ...}}
"""
if not isinstance(transform, dict):
raise ValueError(f"Transform {transform} is not a dict")
return util.load_class(transform_name)(**args)


def _transforms(config: dict) -> tio.transforms.Transform:
def _transforms(transform_dict: dict) -> tio.transforms.Transform:
"""
Define the transforms to apply to the training data
"""
return tio.Compose(
[_load_transform(transform_fcn) for transform_fcn in config["transforms"]]
[
_load_transform(transform_name, args)
for transform_name, args in transform_dict.items()
]
)


Expand Down Expand Up @@ -308,7 +310,7 @@ def read_dicoms_from_disk(
print(f"Test: {test_idx=}")

train_subjects = tio.SubjectsDataset(
[subjects[i] for i in train_idx], transform=_transforms()
[subjects[i] for i in train_idx], transform=_transforms(config["transforms"])
)
val_subjects = tio.SubjectsDataset([subjects[i] for i in val_idx])
test_subject = subjects[test_idx]
Expand Down
4 changes: 2 additions & 2 deletions fishjaw/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ def config() -> dict:

def load_class(name: str) -> type:
"""
Load a class from a string.
Load a class from a module given a string.
:param name: the name of the class to load. Should be in the format module.class,
where module can also contain "."s
where module can also contain "."s (e.g. module.submodule.class)
:returns: the class object
"""
Expand Down
26 changes: 13 additions & 13 deletions userconf.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,20 @@ lr_lambda: 0.9999 # Exponential decay factor (multiplicative with each epoch)

# Options should be passed
transforms:
tio.RandomFlip:
- axes: [0, 1, 2]
- flip_probability: 0.5
tio.RandomAffine:
- p: 0.25
- degrees: 10
- scales: 0.2
torchio.RandomFlip:
axes: [0, 1, 2]
flip_probability: 0.5
torchio.RandomAffine:
p: 0.25
degrees: 10
scales: 0.2
# Other options might be
# tio.RandomBlur(p=0.3),
# tio.RandomBiasField(0.4, p=0.5),
# tio.RandomNoise(0.1, 0.01, p=0.25),
# tio.RandomGamma((-0.3, 0.3), p=0.25),
# tio.ZNormalization(),
# tio.RescaleIntensity(percentiles=(0.5, 99.5)),
# torchio.RandomBlur(p=0.3),
# torchio.RandomBiasField(0.4, p=0.5),
# torchio.RandomNoise(0.1, 0.01, p=0.25),
# torchio.RandomGamma((-0.3, 0.3), p=0.25),
# torchio.ZNormalization(),
# torchio.RescaleIntensity(percentiles=(0.5, 99.5)),

model_params:
model_name: "monai.networks.nets.AttentionUnet"
Expand Down

0 comments on commit 9dcc6d2

Please sign in to comment.