Skip to content

Commit

Permalink
Unbatch before augmentation if necessary.
Browse files Browse the repository at this point in the history
  • Loading branch information
ohinds committed Aug 30, 2023
1 parent 30fc9d2 commit f0d92c3
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions nobrainer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,11 @@ def normalize(self, normalizer):
return self.map(lambda x, y: (normalizer(x), y))

def augment(self, augment_steps, num_parallel_calls=AUTOTUNE):
batch_size = None
if len(self.dataset.shape) > 4:
batch_size = self.batch_size
self.dataset = self.dataset.unbatch()

for transform, kwargs in augment_steps:
self.map(
lambda x, y: tf.cond(
Expand All @@ -270,6 +275,9 @@ def augment(self, augment_steps, num_parallel_calls=AUTOTUNE):
num_parallel_calls=num_parallel_calls,
)

if batch_size:
self.batch(batch_size)

return self

def block(self, block_shape, num_parallel_calls=AUTOTUNE):
Expand Down

0 comments on commit f0d92c3

Please sign in to comment.