From f0d92c39ffec4615e5037af486554b123ed8df8e Mon Sep 17 00:00:00 2001
From: Oliver Hinds
Date: Wed, 30 Aug 2023 14:10:31 -0700
Subject: [PATCH] Unbatch before augmentation if necessary.

---
 nobrainer/dataset.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/nobrainer/dataset.py b/nobrainer/dataset.py
index 01f5f9c4..cd3a983d 100644
--- a/nobrainer/dataset.py
+++ b/nobrainer/dataset.py
@@ -260,6 +260,11 @@ def normalize(self, normalizer):
         return self.map(lambda x, y: (normalizer(x), y))
 
     def augment(self, augment_steps, num_parallel_calls=AUTOTUNE):
+        batch_size = None
+        if len(self.dataset.shape) > 4:
+            batch_size = self.batch_size
+            self.dataset = self.dataset.unbatch()
+
         for transform, kwargs in augment_steps:
             self.map(
                 lambda x, y: tf.cond(
@@ -270,6 +275,9 @@
             num_parallel_calls=num_parallel_calls,
         )
 
+        if batch_size:
+            self.batch(batch_size)
+
         return self
 
     def block(self, block_shape, num_parallel_calls=AUTOTUNE):