diff --git a/nobrainer/dataset.py b/nobrainer/dataset.py index 9a54f686..ff5eef18 100644 --- a/nobrainer/dataset.py +++ b/nobrainer/dataset.py @@ -75,6 +75,7 @@ def from_tfrecords( file_pattern=None, n_volumes=None, volume_shape=None, + block_shape=None, scalar_labels=False, n_classes=1, num_parallel_calls=1, @@ -123,7 +124,16 @@ def from_tfrecords( num_parallel_calls=num_parallel_calls, ) dataset = dataset.map(map_func=parse_fn, num_parallel_calls=num_parallel_calls) - return cls(dataset, n_volumes, volume_shape, n_classes) + ds_obj = cls(dataset, n_volumes, volume_shape, n_classes) + + if block_shape: + ds_obj.block(block_shape) + ds_obj.map_labels() + + # TODO automatically determine batch size + ds_obj.batch(1) + + return ds_obj @classmethod def from_files( @@ -137,6 +147,7 @@ def from_files( num_parallel_calls=1, eval_size=0.1, n_classes=1, + block_shape=None, ): """Create Nobrainer datasets from data filepaths: List(str), list of paths to individual input data files. @@ -195,6 +206,7 @@ def from_files( volume_shape, scalar_labels=scalar_labels, n_classes=n_classes, + block_shape=block_shape, num_parallel_calls=num_parallel_calls, ) ds_eval = None @@ -206,6 +218,7 @@ def from_files( volume_shape, scalar_labels=scalar_labels, n_classes=n_classes, + block_shape=block_shape, num_parallel_calls=num_parallel_calls, ) return ds_train, ds_eval @@ -277,17 +290,19 @@ def _f(x, y): self.dataset = self.dataset.unbatch() return self - def map_labels(self, n_classes, label_mapping=None): - if n_classes < 1: + def map_labels(self, label_mapping=None): + if self.n_classes < 1: raise ValueError("n_classes must be > 0.") - if n_classes == 1: + + if label_mapping is not None: + self.map( + lambda x, y: (x, replace(y, label_mapping=label_mapping))) + + if self.n_classes == 1: self.map(lambda x, y: (x, tf.expand_dims(binarize(y), -1))) - elif n_classes == 2: + elif self.n_classes == 2: self.map(lambda x, y: (x, tf.one_hot(binarize(y), n_classes))) - elif n_classes > 2: - if label_mapping is not None: - self.map( - lambda x, y: (x, replace(y, label_mapping=label_mapping))) + elif self.n_classes > 2: self.map(lambda x, y: (x, tf.one_hot(y, n_classes))) return self