diff --git a/nobrainer/dataset.py b/nobrainer/dataset.py index 01f5f9c4..cd3a983d 100644 --- a/nobrainer/dataset.py +++ b/nobrainer/dataset.py @@ -260,6 +260,11 @@ def normalize(self, normalizer): return self.map(lambda x, y: (normalizer(x), y)) def augment(self, augment_steps, num_parallel_calls=AUTOTUNE): + batch_size = None + if len(self.dataset.shape) > 4: + batch_size = self.batch_size + self.dataset = self.dataset.unbatch() + for transform, kwargs in augment_steps: self.map( lambda x, y: tf.cond( @@ -270,6 +275,9 @@ def augment(self, augment_steps, num_parallel_calls=AUTOTUNE): num_parallel_calls=num_parallel_calls, ) + if batch_size: + self.batch(batch_size) + return self def block(self, block_shape, num_parallel_calls=AUTOTUNE):