From 984f2b0bdf3b984c1e7985079bb766cb6ed0e3d7 Mon Sep 17 00:00:00 2001 From: Richard Lane Date: Mon, 21 Oct 2024 16:36:30 +0100 Subject: [PATCH] skel for loading transforms --- fishjaw/model/data.py | 25 ++++++++++--------------- scripts/explore_hyperparams.py | 3 --- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/fishjaw/model/data.py b/fishjaw/model/data.py index aa0c4a5..47687f2 100644 --- a/fishjaw/model/data.py +++ b/fishjaw/model/data.py @@ -4,7 +4,6 @@ """ import pathlib -from typing import Union import torch import torch.utils @@ -235,26 +234,22 @@ def test_loader( ) +def _load_transform(transform: dict) -> tio.transforms.Transform: + """ + Load a transform from the configuration, which should be provided as a dict of {"name": {"arg1": value1, ...}} + + """ + if not isinstance(transform, dict): + raise ValueError(f"Transform {transform} is not a dict") + + def _transforms(config: dict) -> tio.transforms.Transform: """ Define the transforms to apply to the training data """ return tio.Compose( - [ - tio.RandomFlip(axes=(0, 1, 2), flip_probability=0.5), - tio.RandomAffine( - p=0.25, - degrees=10, - scales=0.2, - ), - # tio.RandomBlur(p=0.3), - # tio.RandomBiasField(0.4, p=0.5), - # tio.RandomNoise(0.1, 0.01, p=0.25), - # tio.RandomGamma((-0.3, 0.3), p=0.25), - # tio.ZNormalization(), - # tio.RescaleIntensity(percentiles=(0.5, 99.5)), - ] + [_load_transform(transform_fcn) for transform_fcn in config["transforms"]] ) diff --git a/scripts/explore_hyperparams.py b/scripts/explore_hyperparams.py index f7de971..1aeddcd 100644 --- a/scripts/explore_hyperparams.py +++ b/scripts/explore_hyperparams.py @@ -268,9 +268,6 @@ def main(*, mode: str, n_steps: int, continue_run: bool): train_subjects, val_subjects, _ = data.read_dicoms_from_disk( example_config, rng, - # Use the same transforms as we would for training - # TODO this should really be in the config - transforms="default", ) # I think this might be doing something slightly wrong - we're getting the data which means,