diff --git a/fishjaw/model/data.py b/fishjaw/model/data.py index aa0c4a5..47687f2 100644 --- a/fishjaw/model/data.py +++ b/fishjaw/model/data.py @@ -4,7 +4,6 @@ """ import pathlib -from typing import Union import torch import torch.utils @@ -235,26 +234,22 @@ def test_loader( ) +def _load_transform(transform: dict) -> tio.transforms.Transform: + """ + Load a transform from the configuration, which should be provided as a dict of {"name": {"arg1": value1, ...}} + + """ + if not isinstance(transform, dict): + raise ValueError(f"Transform {transform} is not a dict") + + def _transforms(config: dict) -> tio.transforms.Transform: """ Define the transforms to apply to the training data """ return tio.Compose( - [ - tio.RandomFlip(axes=(0, 1, 2), flip_probability=0.5), - tio.RandomAffine( - p=0.25, - degrees=10, - scales=0.2, - ), - # tio.RandomBlur(p=0.3), - # tio.RandomBiasField(0.4, p=0.5), - # tio.RandomNoise(0.1, 0.01, p=0.25), - # tio.RandomGamma((-0.3, 0.3), p=0.25), - # tio.ZNormalization(), - # tio.RescaleIntensity(percentiles=(0.5, 99.5)), - ] + [_load_transform(transform_fcn) for transform_fcn in config["transforms"]] ) diff --git a/scripts/explore_hyperparams.py b/scripts/explore_hyperparams.py index f7de971..1aeddcd 100644 --- a/scripts/explore_hyperparams.py +++ b/scripts/explore_hyperparams.py @@ -268,9 +268,6 @@ def main(*, mode: str, n_steps: int, continue_run: bool): train_subjects, val_subjects, _ = data.read_dicoms_from_disk( example_config, rng, - # Use the same transforms as we would for training - # TODO this should really be in the config - transforms="default", ) # I think this might be doing something slightly wrong - we're getting the data which means,