Block and map labels during dataset construction, because the order of these operations matters.
ohinds committed Aug 30, 2023
1 parent d89aa8d commit 2165404
Showing 1 changed file with 24 additions and 9 deletions.
nobrainer/dataset.py (+24 −9)
@@ -75,6 +75,7 @@ def from_tfrecords(
         file_pattern=None,
         n_volumes=None,
         volume_shape=None,
+        block_shape=None,
         scalar_labels=False,
         n_classes=1,
         num_parallel_calls=1,
@@ -123,7 +124,16 @@ def from_tfrecords(
             num_parallel_calls=num_parallel_calls,
         )
         dataset = dataset.map(map_func=parse_fn, num_parallel_calls=num_parallel_calls)
-        return cls(dataset, n_volumes, volume_shape, n_classes)
+        ds_obj = cls(dataset, n_volumes, volume_shape, n_classes)
+
+        if block_shape:
+            ds_obj.block(block_shape)
+        ds_obj.map_labels()
+
+        # TODO automatically determine batch size
+        ds_obj.batch(1)
+
+        return ds_obj

     @classmethod
     def from_files(
@@ -137,6 +147,7 @@ def from_files(
         num_parallel_calls=1,
         eval_size=0.1,
         n_classes=1,
+        block_shape=None,
     ):
         """Create Nobrainer datasets from data
         filepaths: List(str), list of paths to individual input data files.
@@ -195,6 +206,7 @@ def from_files(
             volume_shape,
             scalar_labels=scalar_labels,
             n_classes=n_classes,
+            block_shape=block_shape,
             num_parallel_calls=num_parallel_calls,
         )
         ds_eval = None
@@ -206,6 +218,7 @@ def from_files(
                 volume_shape,
                 scalar_labels=scalar_labels,
                 n_classes=n_classes,
+                block_shape=block_shape,
                 num_parallel_calls=num_parallel_calls,
             )
         return ds_train, ds_eval
@@ -277,17 +290,19 @@ def _f(x, y):
         self.dataset = self.dataset.unbatch()
         return self

-    def map_labels(self, n_classes, label_mapping=None):
-        if n_classes < 1:
+    def map_labels(self, label_mapping=None):
+        if self.n_classes < 1:
             raise ValueError("n_classes must be > 0.")
-        if n_classes == 1:
+
+        if label_mapping is not None:
+            self.map(
+                lambda x, y: (x, replace(y, label_mapping=label_mapping)))
+
+        if self.n_classes == 1:
             self.map(lambda x, y: (x, tf.expand_dims(binarize(y), -1)))
-        elif n_classes == 2:
+        elif self.n_classes == 2:
             self.map(lambda x, y: (x, tf.one_hot(binarize(y), n_classes)))
-        elif n_classes > 2:
-            if label_mapping is not None:
-                self.map(
-                    lambda x, y: (x, replace(y, label_mapping=label_mapping)))
+        elif self.n_classes > 2:
             self.map(lambda x, y: (x, tf.one_hot(y, n_classes)))

         return self
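For orientation, a hedged usage sketch of the construction path this commit changes. The parameter names come from the from_tfrecords signature in the hunks above; the Dataset class name and import path are assumed from nobrainer/dataset.py, and the file pattern, shapes, and counts are illustrative placeholders.

    from nobrainer.dataset import Dataset  # assumed import path

    dataset = Dataset.from_tfrecords(
        file_pattern="data/train-*.tfrec",  # placeholder shard pattern
        n_volumes=10,                       # placeholder volume count
        volume_shape=(256, 256, 256),       # placeholder full-volume shape
        block_shape=(128, 128, 128),        # argument added by this commit
        n_classes=1,
    )
    # As of this commit, from_tfrecords has already applied .block(block_shape),
    # .map_labels(), and .batch(1), in that order, so callers no longer chain
    # these steps (or choose their order) themselves.

Pinning the order inside construction appears to be the point of the commit title: block splits volumes and labels while the labels still have their on-disk shape, and only afterwards does map_labels binarize or one-hot encode them, which can add a class axis that block would otherwise have to account for. That rationale is inferred from the title, not stated in the diff.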
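The map_labels hunk reworks the method to read n_classes from the instance and to apply label_mapping up front for every class count, rather than only inside the old n_classes > 2 branch. One caveat: the unchanged context lines in the == 2 and > 2 branches still close over a bare n_classes name that the new signature no longer binds, so as committed those lambdas would raise a NameError when they execute; they presumably want self.n_classes as well. Below is a minimal sketch of the post-commit dispatch with that substitution made. It assumes binarize and replace are the nobrainer label transforms this module already uses and that tf is TensorFlow; it is a sketch of the method, not the project's verbatim code.

    import tensorflow as tf

    from nobrainer.volume import binarize, replace  # assumed import location

    def map_labels(self, label_mapping=None):
        # Sketch: bare n_classes references are written as self.n_classes,
        # an assumed fix that is not part of this commit.
        if self.n_classes < 1:
            raise ValueError("n_classes must be > 0.")

        # The mapping, when given, now runs first for every class count.
        if label_mapping is not None:
            self.map(lambda x, y: (x, replace(y, label_mapping=label_mapping)))

        if self.n_classes == 1:
            # Binary mask with an explicit channel axis appended.
            self.map(lambda x, y: (x, tf.expand_dims(binarize(y), -1)))
        elif self.n_classes == 2:
            # Binarize, then one-hot into two channels.
            self.map(lambda x, y: (x, tf.one_hot(binarize(y), self.n_classes)))
        else:
            # n_classes > 2: labels are already integer class ids.
            self.map(lambda x, y: (x, tf.one_hot(y, self.n_classes)))

        return self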
