diff --git a/nobrainer/processing/generation.py b/nobrainer/processing/generation.py index b729e68c..51b46625 100644 --- a/nobrainer/processing/generation.py +++ b/nobrainer/processing/generation.py @@ -152,12 +152,12 @@ def _compile(): num_parallel_calls=num_parallel_calls, volume_shape=(resolution, resolution, resolution), n_classes=1, - scalar_labels=True + scalar_labels=True, ) n_epochs = info.get("epochs") or epochs - dataset.batch(batch_size) \ - .normalize(info.get("normalizer") or normalizer) \ - .repeat(n_epochs) + dataset.batch(batch_size).normalize( + info.get("normalizer") or normalizer + ).repeat(n_epochs) with self.strategy.scope(): # grow the networks by one (2^x) resolution