Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 30, 2023
1 parent f0d92c3 commit 6fa5812
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 39 deletions.
18 changes: 12 additions & 6 deletions nobrainer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def from_tfrecords(
# Assume all files have same compression type as the first file.
compression_type = "GZIP" if compressed else None
cycle_length = 1 if num_parallel_calls is None else num_parallel_calls
parse_fn = parse_example_fn(volume_shape=volume_shape, scalar_labels=scalar_labels)
parse_fn = parse_example_fn(
volume_shape=volume_shape, scalar_labels=scalar_labels
)

# Determine examples_per_shard from the first TFRecord shard
# Then set block_length to equal the number of examples per shard
Expand Down Expand Up @@ -185,14 +187,16 @@ def from_files(
volume_shape = nb.load(filepaths[0][0]).shape
write(
features_labels=filepaths[:n_train],
filename_template=template.format(intent="train") + "_shard-{shard:03d}.tfrec",
filename_template=template.format(intent="train")
+ "_shard-{shard:03d}.tfrec",
examples_per_shard=shard_size,
processes=num_parallel_calls,
)
if n_eval > 0:
write(
features_labels=filepaths[n_train:],
filename_template=template.format(intent="eval") + "_shard-{shard:03d}.tfrec",
filename_template=template.format(intent="eval")
+ "_shard-{shard:03d}.tfrec",
examples_per_shard=shard_size,
processes=num_parallel_calls,
)
Expand Down Expand Up @@ -239,7 +243,9 @@ def get_steps_per_epoch(self):
def get_n(a, k):
return (a - k) / k + 1

n_blocks = tuple(get_n(aa, kk) for aa, kk in zip(self.volume_shape, self.block_shape))
n_blocks = tuple(
get_n(aa, kk) for aa, kk in zip(self.volume_shape, self.block_shape)
)

for n in n_blocks:
if not n.is_integer() or n < 1:
Expand Down Expand Up @@ -287,6 +293,7 @@ def block(self, block_shape, num_parallel_calls=AUTOTUNE):
num_parallel_calls=num_parallel_calls,
)
else:

def _f(x, y):
x = to_blocks(x, block_shape)
n_blocks = x.shape[0]
Expand All @@ -303,8 +310,7 @@ def map_labels(self, label_mapping=None):
raise ValueError("n_classes must be > 0.")

if label_mapping is not None:
self.map(
lambda x, y: (x, replace(y, label_mapping=label_mapping)))
self.map(lambda x, y: (x, replace(y, label_mapping=label_mapping)))

if self.n_classes == 1:
self.map(lambda x, y: (x, tf.expand_dims(binarize(y), -1)))
Expand Down
82 changes: 49 additions & 33 deletions nobrainer/tests/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def test_get_dataset_maintains_order(

y_orig = np.array([y for _, y in filepaths])
y_from_dset = (
np.concatenate([y for _, y in dset.dataset.as_numpy_iterator()]).flatten().astype(int)
np.concatenate([y for _, y in dset.dataset.as_numpy_iterator()])
.flatten()
.astype(int)
)
assert_array_equal(y_orig, y_from_dset)
shutil.rmtree(temp_dir)
Expand Down Expand Up @@ -118,13 +120,13 @@ def test_get_dataset_errors_augmentation():
n_volumes=10,
volume_shape=(256, 256, 256),
n_classes=1,
).augment=([
).augment = [
(
intensity_transforms.addGaussianNoise,
{"noise_mean": 0.1, "noise_std": 0.5},
),
(spatial_transforms.randomflip_leftright),
])
]
shutil.rmtree(temp_dir)


Expand All @@ -139,43 +141,57 @@ def test_get_steps_per_epoch():
temp_dir = tempfile.mkdtemp()
nifti_paths = create_dummy_niftis(volume_shape, 10, temp_dir)
filepaths = [(x, i) for i, x in enumerate(nifti_paths)]
file_pattern = write_tfrecs(
filepaths, temp_dir, examples_per_shard=1
file_pattern = write_tfrecs(filepaths, temp_dir, examples_per_shard=1)
dset = (
dataset.Dataset.from_tfrecords(
file_pattern=file_pattern.replace("*", "000"),
n_volumes=1,
volume_shape=volume_shape,
scalar_labels=True,
n_classes=1,
)
.block(block_shape=(64, 64, 64))
.batch(1)
)
dset = dataset.Dataset.from_tfrecords(
file_pattern=file_pattern.replace('*', '000'),
n_volumes=1,
volume_shape=volume_shape,
scalar_labels=True,
n_classes=1,
).block(block_shape=(64, 64, 64)).batch(1)
assert dset.get_steps_per_epoch() == 64

dset = dataset.Dataset.from_tfrecords(
file_pattern=file_pattern.replace('*', '000'),
n_volumes=1,
volume_shape=volume_shape,
scalar_labels=True,
n_classes=1,
).block(block_shape=(64, 64, 64)).batch(64)
dset = (
dataset.Dataset.from_tfrecords(
file_pattern=file_pattern.replace("*", "000"),
n_volumes=1,
volume_shape=volume_shape,
scalar_labels=True,
n_classes=1,
)
.block(block_shape=(64, 64, 64))
.batch(64)
)
assert dset.get_steps_per_epoch() == 1

dset = dataset.Dataset.from_tfrecords(
file_pattern=file_pattern.replace('*', '000'),
n_volumes=1,
volume_shape=volume_shape,
scalar_labels=True,
n_classes=1,
).block(block_shape=(64, 64, 64)).batch(63)
dset = (
dataset.Dataset.from_tfrecords(
file_pattern=file_pattern.replace("*", "000"),
n_volumes=1,
volume_shape=volume_shape,
scalar_labels=True,
n_classes=1,
)
.block(block_shape=(64, 64, 64))
.batch(63)
)
assert dset.get_steps_per_epoch() == 2

dset = dataset.Dataset.from_tfrecords(
file_pattern=file_pattern,
n_volumes=10,
volume_shape=volume_shape,
scalar_labels=True,
n_classes=1,
).block(block_shape=(128, 128, 128)).batch(4)
dset = (
dataset.Dataset.from_tfrecords(
file_pattern=file_pattern,
n_volumes=10,
volume_shape=volume_shape,
scalar_labels=True,
n_classes=1,
)
.block(block_shape=(128, 128, 128))
.batch(4)
)
assert dset.get_steps_per_epoch() == 20

shutil.rmtree(temp_dir)

0 comments on commit 6fa5812

Please sign in to comment.