diff --git a/nobrainer/dataset.py b/nobrainer/dataset.py
index cd3a983d..368ac739 100644
--- a/nobrainer/dataset.py
+++ b/nobrainer/dataset.py
@@ -101,7 +101,9 @@ def from_tfrecords(
         # Assume all files have same compression type as the first file.
         compression_type = "GZIP" if compressed else None
         cycle_length = 1 if num_parallel_calls is None else num_parallel_calls
-        parse_fn = parse_example_fn(volume_shape=volume_shape, scalar_labels=scalar_labels)
+        parse_fn = parse_example_fn(
+            volume_shape=volume_shape, scalar_labels=scalar_labels
+        )
 
         # Determine examples_per_shard from the first TFRecord shard
         # Then set block_length to equal the number of examples per shard
@@ -185,14 +187,16 @@ def from_files(
         volume_shape = nb.load(filepaths[0][0]).shape
         write(
             features_labels=filepaths[:n_train],
-            filename_template=template.format(intent="train") + "_shard-{shard:03d}.tfrec",
+            filename_template=template.format(intent="train")
+            + "_shard-{shard:03d}.tfrec",
             examples_per_shard=shard_size,
             processes=num_parallel_calls,
         )
         if n_eval > 0:
             write(
                 features_labels=filepaths[n_train:],
-                filename_template=template.format(intent="eval") + "_shard-{shard:03d}.tfrec",
+                filename_template=template.format(intent="eval")
+                + "_shard-{shard:03d}.tfrec",
                 examples_per_shard=shard_size,
                 processes=num_parallel_calls,
             )
@@ -239,7 +243,9 @@ def get_steps_per_epoch(self):
         def get_n(a, k):
             return (a - k) / k + 1
 
-        n_blocks = tuple(get_n(aa, kk) for aa, kk in zip(self.volume_shape, self.block_shape))
+        n_blocks = tuple(
+            get_n(aa, kk) for aa, kk in zip(self.volume_shape, self.block_shape)
+        )
 
         for n in n_blocks:
             if not n.is_integer() or n < 1:
@@ -287,6 +293,7 @@ def block(self, block_shape, num_parallel_calls=AUTOTUNE):
                 num_parallel_calls=num_parallel_calls,
             )
         else:
+
             def _f(x, y):
                 x = to_blocks(x, block_shape)
                 n_blocks = x.shape[0]
@@ -303,8 +310,7 @@ def map_labels(self, label_mapping=None):
             raise ValueError("n_classes must be > 0.")
 
         if label_mapping is not None:
-            self.map(
-                lambda x, y: (x, replace(y, label_mapping=label_mapping)))
+            self.map(lambda x, y: (x, replace(y, label_mapping=label_mapping)))
 
         if self.n_classes == 1:
             self.map(lambda x, y: (x, tf.expand_dims(binarize(y), -1)))
diff --git a/nobrainer/tests/dataset_test.py b/nobrainer/tests/dataset_test.py
index 503c9141..a9a0071a 100644
--- a/nobrainer/tests/dataset_test.py
+++ b/nobrainer/tests/dataset_test.py
@@ -52,7 +52,9 @@ def test_get_dataset_maintains_order(
     y_orig = np.array([y for _, y in filepaths])
 
     y_from_dset = (
-        np.concatenate([y for _, y in dset.dataset.as_numpy_iterator()]).flatten().astype(int)
+        np.concatenate([y for _, y in dset.dataset.as_numpy_iterator()])
+        .flatten()
+        .astype(int)
     )
     assert_array_equal(y_orig, y_from_dset)
     shutil.rmtree(temp_dir)
@@ -118,13 +120,13 @@ def test_get_dataset_errors_augmentation():
         n_volumes=10,
         volume_shape=(256, 256, 256),
         n_classes=1,
-    ).augment=([
+    ).augment = [
         (
             intensity_transforms.addGaussianNoise,
             {"noise_mean": 0.1, "noise_std": 0.5},
         ),
         (spatial_transforms.randomflip_leftright),
-    ])
+    ]
 
     shutil.rmtree(temp_dir)
 
@@ -139,43 +141,57 @@ def test_get_steps_per_epoch():
     temp_dir = tempfile.mkdtemp()
     nifti_paths = create_dummy_niftis(volume_shape, 10, temp_dir)
     filepaths = [(x, i) for i, x in enumerate(nifti_paths)]
-    file_pattern = write_tfrecs(
-        filepaths, temp_dir, examples_per_shard=1
+    file_pattern = write_tfrecs(filepaths, temp_dir, examples_per_shard=1)
+    dset = (
+        dataset.Dataset.from_tfrecords(
+            file_pattern=file_pattern.replace("*", "000"),
+            n_volumes=1,
+            volume_shape=volume_shape,
+            scalar_labels=True,
+            n_classes=1,
+        )
+        .block(block_shape=(64, 64, 64))
+        .batch(1)
     )
-    dset = dataset.Dataset.from_tfrecords(
-        file_pattern=file_pattern.replace('*', '000'),
-        n_volumes=1,
-        volume_shape=volume_shape,
-        scalar_labels=True,
-        n_classes=1,
-    ).block(block_shape=(64, 64, 64)).batch(1)
     assert dset.get_steps_per_epoch() == 64
 
-    dset = dataset.Dataset.from_tfrecords(
-        file_pattern=file_pattern.replace('*', '000'),
-        n_volumes=1,
-        volume_shape=volume_shape,
-        scalar_labels=True,
-        n_classes=1,
-    ).block(block_shape=(64, 64, 64)).batch(64)
+    dset = (
+        dataset.Dataset.from_tfrecords(
+            file_pattern=file_pattern.replace("*", "000"),
+            n_volumes=1,
+            volume_shape=volume_shape,
+            scalar_labels=True,
+            n_classes=1,
+        )
+        .block(block_shape=(64, 64, 64))
+        .batch(64)
+    )
     assert dset.get_steps_per_epoch() == 1
 
-    dset = dataset.Dataset.from_tfrecords(
-        file_pattern=file_pattern.replace('*', '000'),
-        n_volumes=1,
-        volume_shape=volume_shape,
-        scalar_labels=True,
-        n_classes=1,
-    ).block(block_shape=(64, 64, 64)).batch(63)
+    dset = (
+        dataset.Dataset.from_tfrecords(
+            file_pattern=file_pattern.replace("*", "000"),
+            n_volumes=1,
+            volume_shape=volume_shape,
+            scalar_labels=True,
+            n_classes=1,
+        )
+        .block(block_shape=(64, 64, 64))
+        .batch(63)
+    )
     assert dset.get_steps_per_epoch() == 2
 
-    dset = dataset.Dataset.from_tfrecords(
-        file_pattern=file_pattern,
-        n_volumes=10,
-        volume_shape=volume_shape,
-        scalar_labels=True,
-        n_classes=1,
-    ).block(block_shape=(128, 128, 128)).batch(4)
+    dset = (
+        dataset.Dataset.from_tfrecords(
+            file_pattern=file_pattern,
+            n_volumes=10,
+            volume_shape=volume_shape,
+            scalar_labels=True,
+            n_classes=1,
+        )
+        .block(block_shape=(128, 128, 128))
+        .batch(4)
+    )
     assert dset.get_steps_per_epoch() == 20
 
     shutil.rmtree(temp_dir)
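
Note on the expected values in `test_get_steps_per_epoch`: they follow from the block arithmetic in `Dataset.get_steps_per_epoch` (the `get_n` helper reformatted above), assuming steps per epoch is the total number of blocks divided by the batch size, rounded up, which is consistent with the values the test asserts. A minimal standalone sketch of that arithmetic; `steps_per_epoch` here is an illustrative helper, not the library API:

```python
import math


def steps_per_epoch(volume_shape, block_shape, n_volumes, batch_size):
    # Non-overlapping blocks per axis: (a - k) / k + 1 == a / k,
    # mirroring get_n in the dataset.py hunk above.
    per_axis = ((a - k) / k + 1 for a, k in zip(volume_shape, block_shape))
    n_blocks = math.prod(int(n) for n in per_axis)
    # One epoch yields n_volumes * n_blocks examples, consumed batch_size at a time.
    return math.ceil(n_blocks * n_volumes / batch_size)


# The four cases asserted in the reformatted test:
assert steps_per_epoch((256,) * 3, (64,) * 3, n_volumes=1, batch_size=1) == 64
assert steps_per_epoch((256,) * 3, (64,) * 3, n_volumes=1, batch_size=64) == 1
assert steps_per_epoch((256,) * 3, (64,) * 3, n_volumes=1, batch_size=63) == 2
assert steps_per_epoch((256,) * 3, (128,) * 3, n_volumes=10, batch_size=4) == 20
```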