Don't rely on initial volume shape to determine current dimensions.
ohinds committed Aug 31, 2023
1 parent d07ab3a commit 8f09852
Showing 2 changed files with 32 additions and 44 deletions.
2 changes: 1 addition & 1 deletion nobrainer/dataset.py
@@ -324,7 +324,7 @@ def map_labels(self, label_mapping=None):
def batch(self, batch_size):
    # If volume_shape is only three dims, add grayscale channel to features.
    # Otherwise, assume that the channels are already in the features.
-   if len(self.volume_shape) == 3:
+   if len(self.dataset.element_spec[0].shape) == 3:
        self.map(lambda x, y: (tf.expand_dims(x, -1), y))
    elif len(self.dataset.element_spec[0].shape) > 4:
        self.dataset = self.dataset.unbatch()
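The fix above matters because transformations applied after construction (such as blocking) change the rank of the tensors flowing through the pipeline, so the volume_shape recorded up front can go stale. A minimal sketch of the idea, using plain tf.data rather than nobrainer, with toy shapes chosen only for illustration:

# Not nobrainer code: shows that Dataset.element_spec tracks the current tensor
# shapes, while a shape tuple saved at construction time does not.
import tensorflow as tf

volume_shape = (8, 8, 8)  # shape recorded when the dataset was built
ds = tf.data.Dataset.from_tensors((tf.zeros(volume_shape), tf.constant(0.0)))

# A blocking-like transform: split the volume into 8 blocks of 4x4x4.
ds = ds.map(lambda x, y: (tf.reshape(x, (8, 4, 4, 4)), y))

print(len(volume_shape))              # 3 -- the initial, now-stale rank
print(len(ds.element_spec[0].shape))  # 4 -- the rank actually in the pipeline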
74 changes: 31 additions & 43 deletions nobrainer/tests/dataset_test.py
@@ -142,56 +142,44 @@ def test_get_steps_per_epoch():
nifti_paths = create_dummy_niftis(volume_shape, 10, temp_dir)
filepaths = [(x, i) for i, x in enumerate(nifti_paths)]
file_pattern = write_tfrecs(filepaths, temp_dir, examples_per_shard=1)
-   dset = (
-       dataset.Dataset.from_tfrecords(
-           file_pattern=file_pattern.replace("*", "000"),
-           n_volumes=1,
-           volume_shape=volume_shape,
-           scalar_labels=True,
-           n_classes=1,
-       )
-       .block(block_shape=(64, 64, 64))
-       .batch(1)
+   dset = dataset.Dataset.from_tfrecords(
+       file_pattern=file_pattern.replace("*", "000"),
+       n_volumes=1,
+       volume_shape=volume_shape,
+       block_shape=(64, 64, 64),
+       scalar_labels=True,
+       n_classes=1,
    )
assert dset.get_steps_per_epoch() == 64

-   dset = (
-       dataset.Dataset.from_tfrecords(
-           file_pattern=file_pattern.replace("*", "000"),
-           n_volumes=1,
-           volume_shape=volume_shape,
-           scalar_labels=True,
-           n_classes=1,
-       )
-       .block(block_shape=(64, 64, 64))
-       .batch(64)
-   )
+   dset = dataset.Dataset.from_tfrecords(
+       file_pattern=file_pattern.replace("*", "000"),
+       n_volumes=1,
+       volume_shape=volume_shape,
+       block_shape=(64, 64, 64),
+       scalar_labels=True,
+       n_classes=1,
+   ).batch(64)
assert dset.get_steps_per_epoch() == 1

-   dset = (
-       dataset.Dataset.from_tfrecords(
-           file_pattern=file_pattern.replace("*", "000"),
-           n_volumes=1,
-           volume_shape=volume_shape,
-           scalar_labels=True,
-           n_classes=1,
-       )
-       .block(block_shape=(64, 64, 64))
-       .batch(63)
-   )
+   dset = dataset.Dataset.from_tfrecords(
+       file_pattern=file_pattern.replace("*", "000"),
+       n_volumes=1,
+       volume_shape=volume_shape,
+       block_shape=(64, 64, 64),
+       scalar_labels=True,
+       n_classes=1,
+   ).batch(63)
assert dset.get_steps_per_epoch() == 2

-   dset = (
-       dataset.Dataset.from_tfrecords(
-           file_pattern=file_pattern,
-           n_volumes=10,
-           volume_shape=volume_shape,
-           scalar_labels=True,
-           n_classes=1,
-       )
-       .block(block_shape=(128, 128, 128))
-       .batch(4)
-   )
+   dset = dataset.Dataset.from_tfrecords(
+       file_pattern=file_pattern.replace("*", "000"),
+       n_volumes=1,
+       volume_shape=volume_shape,
+       block_shape=(64, 64, 64),
+       scalar_labels=True,
+       n_classes=1,
+   ).batch(4)
assert dset.get_steps_per_epoch() == 20

shutil.rmtree(temp_dir)
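For reference, the asserted step counts follow from simple block arithmetic: steps per epoch is the total number of blocks across all volumes divided by the batch size, rounded up. The sketch below is not nobrainer code; it assumes the dummy volumes are 256**3 (implied by the 64-blocks-of-64**3 case), and the helper names are made up for illustration:

# Rough sanity check of the asserted step counts (hypothetical helpers, not nobrainer API).
import math

def blocks_per_volume(volume_shape, block_shape):
    # Number of non-overlapping blocks a single volume splits into.
    n = 1
    for dim, blk in zip(volume_shape, block_shape):
        n *= dim // blk
    return n

def expected_steps(n_volumes, volume_shape, block_shape, batch_size):
    return math.ceil(n_volumes * blocks_per_volume(volume_shape, block_shape) / batch_size)

volume_shape = (256, 256, 256)  # assumed size of the dummy volumes
print(expected_steps(1, volume_shape, (64, 64, 64), batch_size=1))   # 64
print(expected_steps(1, volume_shape, (64, 64, 64), batch_size=64))  # 1
print(expected_steps(1, volume_shape, (64, 64, 64), batch_size=63))  # 2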
