From 9dcc6d2172df9eb3cd8bd4d33f5199680cbfed31 Mon Sep 17 00:00:00 2001 From: Richard Lane Date: Mon, 21 Oct 2024 17:16:00 +0100 Subject: [PATCH] get the transforms from the config file --- fishjaw/model/data.py | 16 +++++++++------- fishjaw/util/util.py | 4 ++-- userconf.yml | 26 +++++++++++++------------- 3 files changed, 24 insertions(+), 22 deletions(-) diff --git a/fishjaw/model/data.py b/fishjaw/model/data.py index 734c104..5c064f3 100644 --- a/fishjaw/model/data.py +++ b/fishjaw/model/data.py @@ -12,7 +12,7 @@ from tqdm import tqdm from ..images import io, transform -from ..util import files +from ..util import files, util def get_patch_size(config: dict) -> tuple[int, int, int]: @@ -234,22 +234,24 @@ def test_loader( ) -def _load_transform(transform: dict) -> tio.transforms.Transform: +def _load_transform(transform_name: str, args: dict) -> tio.transforms.Transform: """ Load a transform from the configuration, which should be provided as a dict of {"name": {"arg1": value1, ...}} """ - if not isinstance(transform, dict): - raise ValueError(f"Transform {transform} is not a dict") + return util.load_class(transform_name)(**args) -def _transforms(config: dict) -> tio.transforms.Transform: +def _transforms(transform_dict: dict) -> tio.transforms.Transform: """ Define the transforms to apply to the training data """ return tio.Compose( - [_load_transform(transform_fcn) for transform_fcn in config["transforms"]] + [ + _load_transform(transform_name, args) + for transform_name, args in transform_dict.items() + ] ) @@ -308,7 +310,7 @@ def read_dicoms_from_disk( print(f"Test: {test_idx=}") train_subjects = tio.SubjectsDataset( - [subjects[i] for i in train_idx], transform=_transforms() + [subjects[i] for i in train_idx], transform=_transforms(config["transforms"]) ) val_subjects = tio.SubjectsDataset([subjects[i] for i in val_idx]) test_subject = subjects[test_idx] diff --git a/fishjaw/util/util.py b/fishjaw/util/util.py index ba1d2ec..4dffdc8 100644 --- a/fishjaw/util/util.py +++ b/fishjaw/util/util.py @@ -69,10 +69,10 @@ def config() -> dict: def load_class(name: str) -> type: """ - Load a class from a string. + Load a class from a module given a string. :param name: the name of the class to load. Should be in the format module.class, - where module can also contain "."s + where module can also contain "."s (e.g. module.submodule.class) :returns: the class object """ diff --git a/userconf.yml b/userconf.yml index 6d793b7..8e319ff 100644 --- a/userconf.yml +++ b/userconf.yml @@ -40,20 +40,20 @@ lr_lambda: 0.9999 # Exponential decay factor (multiplicative with each epoch) # Options should be passed transforms: - tio.RandomFlip: - - axes: [0, 1, 2] - - flip_probability: 0.5 - tio.RandomAffine: - - p: 0.25 - - degrees: 10 - - scales: 0.2 + torchio.RandomFlip: + axes: [0, 1, 2] + flip_probability: 0.5 + torchio.RandomAffine: + p: 0.25 + degrees: 10 + scales: 0.2 # Other options might be -# tio.RandomBlur(p=0.3), -# tio.RandomBiasField(0.4, p=0.5), -# tio.RandomNoise(0.1, 0.01, p=0.25), -# tio.RandomGamma((-0.3, 0.3), p=0.25), -# tio.ZNormalization(), -# tio.RescaleIntensity(percentiles=(0.5, 99.5)), +# torchio.RandomBlur(p=0.3), +# torchio.RandomBiasField(0.4, p=0.5), +# torchio.RandomNoise(0.1, 0.01, p=0.25), +# torchio.RandomGamma((-0.3, 0.3), p=0.25), +# torchio.ZNormalization(), +# torchio.RescaleIntensity(percentiles=(0.5, 99.5)), model_params: model_name: "monai.networks.nets.AttentionUnet"