Skip to content

Commit

Permalink
Don't map scalar labels
Browse files Browse the repository at this point in the history
  • Loading branch information
ohinds committed Aug 31, 2023
1 parent bfad67c commit e5f46aa
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions nobrainer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def from_tfrecords(

if block_shape:
ds_obj.block(block_shape)
ds_obj.map_labels()

if not scalar_labels:
ds_obj.map_labels()
# TODO automatically determine batch size
ds_obj.batch(1)

Expand Down Expand Up @@ -237,7 +237,7 @@ def block_shape(self):

@property
def scalar_labels(self):
return len(self.dataset.element_spec[1].shape) == 1
return _labels_all_scalar([y for _, y in self.dataset.as_numpy_iterator()])

def get_steps_per_epoch(self):
def get_n(a, k):
Expand Down Expand Up @@ -326,6 +326,8 @@ def batch(self, batch_size):
# Otherwise, assume that the channels are already in the features.
if len(self.volume_shape) == 3:
self.map(lambda x, y: (tf.expand_dims(x, -1), y))
elif len(self.dataset.element_spec[0].shape) > 4:
self.dataset = self.dataset.unbatch()

# Prefetch data to overlap data production with data consumption. The
# TensorFlow documentation suggests prefetching `batch_size` elements.
Expand Down

0 comments on commit e5f46aa

Please sign in to comment.