diff --git a/nobrainer/dataset.py b/nobrainer/dataset.py index 25f68add..df8136ac 100644 --- a/nobrainer/dataset.py +++ b/nobrainer/dataset.py @@ -11,239 +11,11 @@ from .io import _is_gzipped, verify_features_labels from .tfrecord import _labels_all_scalar, parse_example_fn, write -from .volume import binarize, replace, standardize, to_blocks +from .volume import binarize, replace, to_blocks AUTOTUNE = tf.data.experimental.AUTOTUNE -def tfrecord_dataset( - file_pattern, - volume_shape, - shuffle, - scalar_label, - compressed=True, - num_parallel_calls=AUTOTUNE, -): - """Return `tf.data.Dataset` from TFRecord files.""" - dataset = tf.data.Dataset.list_files(file_pattern, shuffle=shuffle) - # Read each of these files as a TFRecordDataset. - # Assume all files have same compression type as the first file. - compression_type = "GZIP" if compressed else None - cycle_length = 1 if num_parallel_calls is None else num_parallel_calls - parse_fn = parse_example_fn(volume_shape=volume_shape, scalar_label=scalar_label) - - if not shuffle: - # Determine examples_per_shard from the first TFRecord shard - # Then set block_length to equal the number of examples per shard - # so that the interleave method does not inadvertently shuffle data. - first_shard = ( - dataset.take(1) - .flat_map( - lambda x: tf.data.TFRecordDataset(x, compression_type=compression_type) - ) - .map(map_func=parse_fn, num_parallel_calls=num_parallel_calls) - ) - block_length = len([0 for _ in first_shard]) - else: - # If the dataset is being shuffled, then we don't care if interleave - # further shuffles that data even further - block_length = None - - dataset = dataset.interleave( - map_func=lambda x: tf.data.TFRecordDataset( - x, compression_type=compression_type - ), - cycle_length=cycle_length, - block_length=block_length, - num_parallel_calls=num_parallel_calls, - ) - dataset = dataset.map(map_func=parse_fn, num_parallel_calls=num_parallel_calls) - return dataset - - -def get_dataset( - file_pattern, - n_classes, - batch_size, - volume_shape, - scalar_label=False, - block_shape=None, - n_epochs=None, - mapping=None, - augment=None, - normalizer=standardize, - shuffle_buffer_size=None, - num_parallel_calls=AUTOTUNE, -): - """Return `tf.data.Dataset` that preprocesses data for training or prediction. - - Labels are preprocessed for binary or multiclass segmentation according to - `n_classes`. - - Parameters - ---------- - file_pattern: str, expression that can be globbed to get TFRecords files - for this dataset. For example 'data/training_*.tfrecords'. - n_classes: int, number of classes to segment. Values of 1 and 2 indicate - binary segmentation (foreground vs background), and values greater than - 2 indicate multiclass segmentation. - batch_size: int, number of elements per batch. - volume_shape: tuple of at least length 3, the shape of every volume in the TFRecords - files. Every volume must have the same shape. - scalar_label: boolean, if `True`, labels are scalars. - block_shape: tuple of at least length 3, the shape of the non-overlapping sub-volumes - to take from the full volumes. If None, do not separate the full volumes - into sub-volumes. Separating into non-overlapping sub-volumes is useful - (sometimes even necessary) to overcome memory limitations depending on - the number of model parameters. - n_epochs: int, number of epochs for the dataset to repeat. If None, the - dataset will be repeated indefinitely. - mapping: dict, mapping to replace label values. 
Values equal to a key in - the mapping are replaced with the corresponding values in the mapping. - Values not in `mapping.keys()` are replaced with zeros. - augment: None, or list of different transforms in the executable sequence - the corresponding arguments in tuple as e.g.: - [(addGaussianNoise, {'noise_mean':0.1,'noise_std':0.5}), (...)] - normalizer: callable, applies this normalization function when creating the - dataset. to maintain compatibility with prior nobrainer release, this is - set to standardize by default. - shuffle_buffer_size: int, buffer of full volumes to shuffle. If this is not - None, then the list of files found by 'file_pattern' is also shuffled - at every iteration. - num_parallel_calls: int, number of parallel calls to make for data loading - and processing. - - Returns - ------- - `tf.data.Dataset` of features and labels. If block_shape is not None, the - shape of features is `(batch_size, *block_shape, 1)` and the shape of labels - is `(batch_size, *block_shape, n_classes)`. If block_shape is None, then - the shape of features is `(batch_size, *volume_shape, 1)` and the shape of - labels is `(batch_size, *volume_shape, n_classes)`. If `scalar_label` is `True, - the shape of labels is always `(batch_size,)`. - """ - - fs, _, _ = fsspec.get_fs_token_paths(file_pattern) - files = fs.glob(file_pattern) - if not files: - raise ValueError("no files found for pattern '{}'".format(file_pattern)) - - # Create dataset of all TFRecord files. After this point, the dataset will have - # two value per iteration: (feature, label). - shuffle = bool(shuffle_buffer_size) - compressed = _is_gzipped(files[0], filesys=fs) - dataset = tfrecord_dataset( - file_pattern=file_pattern, - volume_shape=volume_shape, - shuffle=shuffle, - scalar_label=scalar_label, - compressed=compressed, - num_parallel_calls=num_parallel_calls, - ) - - if normalizer is not None: - # Standard-score the features. - dataset = dataset.map(lambda x, y: (normalizer(x), y)) - - # Augment examples if requested. - if isinstance(augment, bool): - raise ValueError("Augment no longer supports a boolean expression") - - if augment is not None: - for transform, kwargs in augment: - dataset = dataset.map( - lambda x, y: tf.cond( - tf.random.uniform((1,)) > 0.5, - true_fn=lambda: transform(x, y, **kwargs), - false_fn=lambda: (x, y), - ), - num_parallel_calls=num_parallel_calls, - ) - - # Separate into blocks, if requested. - if block_shape is not None: - if not scalar_label: - dataset = dataset.map( - lambda x, y: (to_blocks(x, block_shape), to_blocks(y, block_shape)), - num_parallel_calls=num_parallel_calls, - ) - # This step is necessary because separating into blocks adds a dimension. - dataset = dataset.unbatch() - if scalar_label: - - def _f(x, y): - x = to_blocks(x, block_shape) - n_blocks = x.shape[0] - y = tf.repeat(y, n_blocks) - return (x, y) - - dataset = dataset.map(_f, num_parallel_calls=num_parallel_calls) - # This step is necessary because separating into blocks adds a dimension. - dataset = dataset.unbatch() - else: - if scalar_label: - dataset = dataset.map(lambda x, y: (x, tf.squeeze(y))) - - # Binarize or replace labels according to mapping. 
- if not scalar_label: - if n_classes < 1: - raise ValueError("n_classes must be > 0.") - elif n_classes == 1: - dataset = dataset.map(lambda x, y: (x, tf.expand_dims(binarize(y), -1))) - elif n_classes == 2: - dataset = dataset.map(lambda x, y: (x, tf.one_hot(binarize(y), n_classes))) - elif n_classes > 2: - if mapping is not None: - dataset = dataset.map(lambda x, y: (x, replace(y, mapping=mapping))) - dataset = dataset.map(lambda x, y: (x, tf.one_hot(y, n_classes))) - - # If volume_shape is only three dims, add grayscale channel to features. - # Otherwise, assume that the channels are already in the features. - if len(volume_shape) == 3: - dataset = dataset.map(lambda x, y: (tf.expand_dims(x, -1), y)) - - # Prefetch data to overlap data production with data consumption. The - # TensorFlow documentation suggests prefetching `batch_size` elements. - dataset = dataset.prefetch(buffer_size=batch_size) - - # Batch the dataset, so each iteration gives `batch_size` elements. We drop - # the remainder so that when training on multiple GPUs, the batch will - # always be evenly divisible by the number of GPUs. Otherwise, the last - # batch might have fewer than `batch_size` elements and will cause errors. - if batch_size is not None: - dataset = dataset.batch(batch_size=batch_size, drop_remainder=True) - - # Optionally shuffle. We also optionally shuffle the list of files. - # The TensorFlow recommend shuffling and then repeating. - if shuffle_buffer_size: - dataset = dataset.shuffle(buffer_size=shuffle_buffer_size) - - # Repeat the dataset for n_epochs. If n_epochs is None, then repeat - # indefinitely. If n_epochs is 1, then the dataset will only be iterated - # through once. - dataset = dataset.repeat(n_epochs) - - return dataset - - -def get_steps_per_epoch(n_volumes, volume_shape, block_shape, batch_size): - def get_n(a, k): - return (a - k) / k + 1 - - n_blocks = tuple(get_n(aa, kk) for aa, kk in zip(volume_shape, block_shape)) - - for n in n_blocks: - if not n.is_integer() or n < 1: - raise ValueError( - "cannot create non-overlapping blocks with the given parameters." - ) - n_blocks_per_volume = np.prod(n_blocks).astype(int) - - steps = n_blocks_per_volume * n_volumes / batch_size - steps = math.ceil(steps) - return steps - - def write_multi_resolution( paths, tfrecdir=Path(os.getcwd()) / "data", @@ -286,137 +58,298 @@ class Dataset: """Datasets for training, and validation""" def __init__( - self, n_classes, batch_size, block_shape, volume_shape=None, n_epochs: int = 1 + self, + dataset, + n_volumes, + volume_shape, + n_classes, ): - self.n_classes = n_classes + self.dataset = dataset + self.n_volumes = n_volumes self.volume_shape = volume_shape - self.block_shape = block_shape - self.batch_size = batch_size - self.n_epochs = n_epochs + self.n_classes = n_classes + @classmethod def from_tfrecords( - self, - volume_shape, - scalar_labels, - n_volumes, - template="data/data-train_shard-*.tfrec", - augment=None, - shuffle_buffer_size=None, + cls, + file_pattern=None, + n_volumes=None, + volume_shape=None, + block_shape=None, + scalar_labels=False, + n_classes=1, num_parallel_calls=1, ): - """Function to retrieve a saved tf record as a Dataset + """Function to retrieve a saved tf record as a nobrainer Dataset - template: str, the path to which TFRecord files should be written. + file_pattern: str, the path from which TFRecord files should be read. num_parallel_calls: int, number of processes to use for multiprocessing. If None, will use all available processes. 
""" - self.volume_shape = volume_shape - # replace shard formatting code with * for globbing - dataset = get_dataset( - file_pattern=template, - n_classes=self.n_classes, - batch_size=self.batch_size, - volume_shape=self.volume_shape, - block_shape=self.block_shape, - n_epochs=self.n_epochs, - augment=augment, - shuffle_buffer_size=shuffle_buffer_size, + fs, _, _ = fsspec.get_fs_token_paths(file_pattern) + files = fs.glob(file_pattern) + if not files: + raise ValueError("no files found for pattern '{}'".format(file_pattern)) + + # Create dataset of all TFRecord files. After this point, the dataset will have + # two value per iteration: (feature, label). + compressed = _is_gzipped(files[0], filesys=fs) + dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False) + + # Read each of these files as a TFRecordDataset. + # Assume all files have same compression type as the first file. + compression_type = "GZIP" if compressed else None + cycle_length = 1 if num_parallel_calls is None else num_parallel_calls + parse_fn = parse_example_fn( + volume_shape=volume_shape, scalar_labels=scalar_labels + ) + + # Determine examples_per_shard from the first TFRecord shard + # Then set block_length to equal the number of examples per shard + # so that the interleave method does not inadvertently shuffle data. + first_shard = ( + dataset.take(1) + .flat_map( + lambda x: tf.data.TFRecordDataset(x, compression_type=compression_type) + ) + .map(map_func=parse_fn, num_parallel_calls=num_parallel_calls) + ) + block_length = len([0 for _ in first_shard]) + + dataset = dataset.interleave( + map_func=lambda x: tf.data.TFRecordDataset( + x, compression_type=compression_type + ), + cycle_length=cycle_length, + block_length=block_length, num_parallel_calls=num_parallel_calls, ) - # Add nobrainer specific attributes - dataset.scalar_labels = scalar_labels - dataset.n_volumes = n_volumes - dataset.volume_shape = self.volume_shape - return dataset + dataset = dataset.map(map_func=parse_fn, num_parallel_calls=num_parallel_calls) + ds_obj = cls(dataset, n_volumes, volume_shape, n_classes) + + if block_shape: + ds_obj.block(block_shape) + if not scalar_labels: + ds_obj.map_labels() + # TODO automatically determine batch size + ds_obj.batch(1) + return ds_obj + + @classmethod def from_files( - self, - paths, - eval_size=0.1, - tfrecdir=Path(os.getcwd()) / "data", - shard_size=3, - augment=None, - shuffle_buffer_size=None, - num_parallel_calls=1, + cls, + filepaths, check_shape=True, check_labels_int=False, check_labels_gte_zero=False, + out_tfrec_dir=Path(os.getcwd()) / "data", + shard_size=300, + num_parallel_calls=1, + eval_size=0.1, + n_classes=1, + block_shape=None, ): """Create Nobrainer datasets from data - - template: str, the path to which TFRecord files should be written. A string - formatting key `shard` should be included to indicate the unique TFRecord file - when writing to multiple TFRecord files. For example, - `data_shard-{shard:03d}.tfrec`. - shard_size: int, number of pairs of `(feature, label)` per TFRecord file. + filepaths: List(str), list of paths to individual input data files. check_shape: boolean, if true, validate that the shape of both volumes is equal to 'volume_shape'. check_labels_int: boolean, if true, validate that every labels volume is an integer type or can be safely converted to an integer type. check_labels_gte_zero: boolean, if true, validate that every labels volume has values greater than or equal to zero. 
+ out_tfrec_dir: str, the directory to which TFRecord files should be written. + shard_size: int, number of pairs of `(feature, label)` per TFRecord file. num_parallel_calls: int, number of processes to use for multiprocessing. If None, will use all available processes. + eval_size: float, proportion of the input files to reserve for validation. + n_classes: int, number of output classes """ - # Test that the `filename_template` has a `shard` formatting key. - template = str(Path(tfrecdir) / "data-{intent}") - shard_ext = "shard-{shard:03d}.tfrec" - - Neval = np.ceil(len(paths) * eval_size).astype(int) - Ntrain = len(paths) - Neval + n_eval = np.ceil(len(filepaths) * eval_size).astype(int) + n_train = len(filepaths) - n_eval verify_result = verify_features_labels( - paths, + filepaths, check_shape=check_shape, check_labels_int=check_labels_int, check_labels_gte_zero=check_labels_gte_zero, ) - if len(verify_result) == 0: - Path(tfrecdir).mkdir(exist_ok=True, parents=True) - if self.volume_shape is None: - self.volume_shape = nb.load(paths[0][0]).shape + if len(verify_result) != 0: + raise ValueError( + "Provided filepaths did not pass validation. Please " + "check that they have the same shape, and the " + "targets have appropriate labels" + ) + + Path(out_tfrec_dir).mkdir(exist_ok=True, parents=True) + template = str(Path(out_tfrec_dir) / "data-{intent}") + volume_shape = nb.load(filepaths[0][0]).shape + write( + features_labels=filepaths[:n_train], + filename_template=template.format(intent="train") + + "_shard-{shard:03d}.tfrec", + examples_per_shard=shard_size, + processes=num_parallel_calls, + ) + if n_eval > 0: write( - features_labels=paths[:Ntrain], - filename_template=template.format(intent=f"train_{shard_ext}"), + features_labels=filepaths[n_train:], + filename_template=template.format(intent="eval") + + "_shard-{shard:03d}.tfrec", examples_per_shard=shard_size, processes=num_parallel_calls, ) - if Neval > 0: - write( - features_labels=paths[Ntrain:], - filename_template=template.format(intent=f"eval_{shard_ext}"), - examples_per_shard=shard_size, - processes=num_parallel_calls, - ) - labels = (y for _, y in paths) - scalar_labels = _labels_all_scalar(labels) - # replace shard formatting code with * for globbing - template_train = template.format(intent="train_*.tfrec") - ds_train = self.from_tfrecords( - self.volume_shape, - scalar_labels, - len(paths[:Ntrain]), - template=template_train, - augment=augment, - shuffle_buffer_size=shuffle_buffer_size, + labels = (y for _, y in filepaths) + scalar_labels = _labels_all_scalar(labels) + # replace shard formatting code with * for globbing + template_train = template.format(intent="train_*.tfrec") + ds_train = cls.from_tfrecords( + template_train, + n_train, + volume_shape, + scalar_labels=scalar_labels, + n_classes=n_classes, + block_shape=block_shape, + num_parallel_calls=num_parallel_calls, + ) + ds_eval = None + if n_eval > 0: + template_eval = template.format(intent="eval_*.tfrec") + ds_eval = cls.from_tfrecords( + template_eval, + n_eval, + volume_shape, + scalar_labels=scalar_labels, + n_classes=n_classes, + block_shape=block_shape, num_parallel_calls=num_parallel_calls, ) - ds_eval = None - if Neval > 0: - template_eval = template.format(intent="eval_*.tfrec") - ds_eval = self.from_tfrecords( - self.volume_shape, - scalar_labels, - len(paths[Ntrain:]), - template=template_eval, - augment=None, - shuffle_buffer_size=None, - num_parallel_calls=num_parallel_calls, - ) - return ds_train, ds_eval - raise ValueError( - "Provided paths 
did not pass validation. Please " - "check that they have the same shape, and the " - "targets have appropriate labels" + return ds_train, ds_eval + + @property + def batch_size(self): + return self.dataset.element_spec[0].shape[0] + + @property + def block_shape(self): + return tuple(self.dataset.element_spec[0].shape[1:4].as_list()) + + @property + def scalar_labels(self): + return _labels_all_scalar([y for _, y in self.dataset.as_numpy_iterator()]) + + def get_steps_per_epoch(self): + def get_n(a, k): + return (a - k) / k + 1 + + n_blocks = tuple( + get_n(aa, kk) for aa, kk in zip(self.volume_shape, self.block_shape) ) + + for n in n_blocks: + if not n.is_integer() or n < 1: + raise ValueError( + "cannot create non-overlapping blocks with the given parameters." + ) + n_blocks_per_volume = np.prod(n_blocks).astype(int) + + steps = n_blocks_per_volume * self.n_volumes / self.batch_size + steps = math.ceil(steps) + return steps + + def map(self, func, num_parallel_calls=AUTOTUNE): + self.dataset = self.dataset.map(func, num_parallel_calls=num_parallel_calls) + return self + + def normalize(self, normalizer): + return self.map(lambda x, y: (normalizer(x), y)) + + def augment(self, augment_steps, num_parallel_calls=AUTOTUNE): + batch_size = None + if len(self.dataset.element_spec[0].shape) > 4: + batch_size = self.batch_size + self.dataset = self.dataset.unbatch() + + for transform, kwargs in augment_steps: + self.map( + lambda x, y: tf.cond( + tf.random.uniform((1,)) > 0.5, + true_fn=lambda: transform(x, y, **kwargs), + false_fn=lambda: (x, y), + ), + num_parallel_calls=num_parallel_calls, + ) + + if batch_size: + self.batch(batch_size) + + return self + + def block(self, block_shape, num_parallel_calls=AUTOTUNE): + if not self.scalar_labels: + self.map( + lambda x, y: (to_blocks(x, block_shape), to_blocks(y, block_shape)), + num_parallel_calls=num_parallel_calls, + ) + else: + + def _f(x, y): + x = to_blocks(x, block_shape) + n_blocks = x.shape[0] + y = tf.repeat(y, n_blocks) + return (x, y) + + self.map(_f, num_parallel_calls=num_parallel_calls) + # This step is necessary because separating into blocks adds a dimension. + self.dataset = self.dataset.unbatch() + return self + + def map_labels(self, label_mapping=None): + if self.n_classes < 1: + raise ValueError("n_classes must be > 0.") + + if label_mapping is not None: + self.map(lambda x, y: (x, replace(y, label_mapping=label_mapping))) + + if self.n_classes == 1: + self.map(lambda x, y: (x, tf.expand_dims(binarize(y), -1))) + elif self.n_classes == 2: + self.map(lambda x, y: (x, tf.one_hot(binarize(y), self.n_classes))) + elif self.n_classes > 2: + self.map(lambda x, y: (x, tf.one_hot(y, self.n_classes))) + + return self + + def batch(self, batch_size): + # If volume_shape is only three dims, add grayscale channel to features. + # Otherwise, assume that the channels are already in the features. + if len(self.dataset.element_spec[0].shape) == 3: + self.map(lambda x, y: (tf.expand_dims(x, -1), y)) + elif len(self.dataset.element_spec[0].shape) > 4: + self.dataset = self.dataset.unbatch() + + # Prefetch data to overlap data production with data consumption. The + # TensorFlow documentation suggests prefetching `batch_size` elements. + self.dataset = self.dataset.prefetch(buffer_size=batch_size) + + # Batch the dataset, so each iteration gives `batch_size` elements. We drop + # the remainder so that when training on multiple GPUs, the batch will + # always be evenly divisible by the number of GPUs. 
Otherwise, the last + # batch might have fewer than `batch_size` elements and will cause errors. + self.dataset = self.dataset.batch(batch_size=batch_size, drop_remainder=True) + + return self + + def shuffle(self, shuffle_buffer_size): + # Optionally shuffle. We also optionally shuffle the list of files. + # The TensorFlow recommend shuffling and then repeating. + self.dataset = self.dataset.shuffle(buffer_size=shuffle_buffer_size) + return self + + def repeat(self, n_repeats): + # Repeat the dataset for n_epochs. If n_epochs is None, then repeat + # indefinitely. If n_epochs is 1, then the dataset will only be iterated + # through once. + self.dataset = self.dataset.repeat(n_repeats) + return self diff --git a/nobrainer/processing/segmentation.py b/nobrainer/processing/segmentation.py index 5d1e5484..480f146c 100644 --- a/nobrainer/processing/segmentation.py +++ b/nobrainer/processing/segmentation.py @@ -5,7 +5,6 @@ from .base import BaseEstimator from .. import losses, metrics -from ..dataset import get_steps_per_epoch logging.getLogger().setLevel(logging.INFO) @@ -44,15 +43,11 @@ def fit( """Train a segmentation model""" # TODO: check validity of datasets - # extract dataset information - batch_size = dataset_train.element_spec[0].shape[0] - self.block_shape_ = tuple(dataset_train.element_spec[0].shape[1:4]) + batch_size = dataset_train.batch_size + self.block_shape_ = dataset_train.block_shape self.volume_shape_ = dataset_train.volume_shape - self.scalar_labels_ = True - n_classes = 1 - if len(dataset_train.element_spec[1].shape) > 1: - n_classes = dataset_train.element_spec[1].shape[4] - self.scalar_labels_ = False + self.scalar_labels_ = dataset_train.scalar_labels + n_classes = dataset_train.n_classes opt_args = opt_args or {} if optimizer is None: optimizer = tf.keras.optimizers.Adam @@ -87,32 +82,17 @@ def _compile(): _compile() self.model_.summary() - train_steps = get_steps_per_epoch( - n_volumes=dataset_train.n_volumes, - volume_shape=self.volume_shape_, - block_shape=self.block_shape_, - batch_size=batch_size, - ) - - evaluate_steps = None - if dataset_validate is not None: - evaluate_steps = get_steps_per_epoch( - n_volumes=dataset_validate.n_volumes, - volume_shape=self.volume_shape_, - block_shape=self.block_shape_, - batch_size=batch_size, - ) - callbacks = [] if self.checkpoint_tracker: callbacks.append(self.checkpoint_tracker) - self.model_.fit( - dataset_train, + dataset_train.dataset, epochs=epochs, - steps_per_epoch=train_steps, - validation_data=dataset_validate, - validation_steps=evaluate_steps, + steps_per_epoch=dataset_train.get_steps_per_epoch(), + validation_data=dataset_validate.dataset if dataset_validate else None, + validation_steps=dataset_validate.get_steps_per_epoch() + if dataset_validate + else None, callbacks=callbacks, ) diff --git a/nobrainer/tests/checkpoint_test.py b/nobrainer/tests/checkpoint_test.py index b3d03716..e780c824 100644 --- a/nobrainer/tests/checkpoint_test.py +++ b/nobrainer/tests/checkpoint_test.py @@ -6,6 +6,7 @@ from numpy.testing import assert_allclose import tensorflow as tf +from nobrainer.dataset import Dataset from nobrainer.models import meshnet from nobrainer.processing.segmentation import Segmentation @@ -15,10 +16,7 @@ def _get_toy_dataset(): train = tf.data.Dataset.from_tensors( (np.random.rand(*data_shape), np.random.randint(0, 1, data_shape)) ) - train.scalar_labels = False - train.n_volumes = data_shape[0] - train.volume_shape = data_shape[1:4] - return train + return Dataset(train, data_shape[0], data_shape[1:4], 1) 
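    # The positional arguments above follow the new wrapper constructor,
    # Dataset(dataset, n_volumes, volume_shape, n_classes), so an in-memory
    # tf.data.Dataset can be handed to Segmentation.fit, which now reads
    # batch_size, block_shape, and scalar_labels from the wrapper.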
def _assert_model_weights_allclose(model1, model2): diff --git a/nobrainer/tests/dataset_test.py b/nobrainer/tests/dataset_test.py index 4fe9210f..5bbe96a2 100644 --- a/nobrainer/tests/dataset_test.py +++ b/nobrainer/tests/dataset_test.py @@ -41,18 +41,20 @@ def test_get_dataset_maintains_order( filepaths, temp_dir, examples_per_shard=examples_per_shard ) volume_shape = (256, 256, 256) - dset = dataset.get_dataset( + dset = dataset.Dataset.from_tfrecords( file_pattern=file_pattern, - n_classes=1, - batch_size=batch_size, + n_volumes=10, volume_shape=volume_shape, - scalar_label=True, - n_epochs=1, + scalar_labels=True, + n_classes=1, num_parallel_calls=num_parallel_calls, - ) + ).batch(batch_size) + y_orig = np.array([y for _, y in filepaths]) y_from_dset = ( - np.concatenate([y for _, y in dset.as_numpy_iterator()]).flatten().astype(int) + np.concatenate([y for _, y in dset.dataset.as_numpy_iterator()]) + .flatten() + .astype(int) ) assert_array_equal(y_orig, y_from_dset) shutil.rmtree(temp_dir) @@ -72,11 +74,11 @@ def test_get_dataset_errors(): temp_dir = tempfile.mkdtemp() file_pattern = op.join(temp_dir, "does_not_exist-*.tfrec") with pytest.raises(ValueError): - dataset.get_dataset( - file_pattern=file_pattern, + dataset.Dataset.from_tfrecords( + file_pattern, + None, + (256, 256, 256), n_classes=1, - batch_size=1, - volume_shape=(256, 256, 256), ) @@ -93,19 +95,18 @@ def test_get_dataset_shapes( file_pattern = write_tfrecs( filepaths, temp_dir, examples_per_shard=examples_per_shard ) - dset = dataset.get_dataset( + dset = dataset.Dataset.from_tfrecords( file_pattern=file_pattern, - n_classes=1, - batch_size=batch_size, + n_volumes=len(filepaths), volume_shape=volume_shape, - scalar_label=True, - n_epochs=1, + scalar_labels=True, + n_classes=1, num_parallel_calls=num_parallel_calls, - ) + ).batch(batch_size) output_volume_shape = volume_shape if len(volume_shape) > 3 else volume_shape + (1,) output_volume_shape = (batch_size,) + output_volume_shape - shapes = [x.shape for x, _ in dset.as_numpy_iterator()] + shapes = [x.shape for x, _ in dset.dataset.as_numpy_iterator()] assert all([_shape == output_volume_shape for _shape in shapes]) shutil.rmtree(temp_dir) @@ -114,19 +115,19 @@ def test_get_dataset_errors_augmentation(): temp_dir = tempfile.mkdtemp() file_pattern = op.join(temp_dir, "does_not_exist-*.tfrec") with pytest.raises(ValueError): - dataset.get_dataset( + dataset.Dataset.from_tfrecords( file_pattern=file_pattern, - n_classes=1, - batch_size=1, + n_volumes=10, volume_shape=(256, 256, 256), - augment=[ - ( - intensity_transforms.addGaussianNoise, - {"noise_mean": 0.1, "noise_std": 0.5}, - ), - (spatial_transforms.randomflip_leftright), - ], - ) + n_classes=1, + ).augment = [ + ( + intensity_transforms.addGaussianNoise, + {"noise_mean": 0.1, "noise_std": 0.5}, + ), + (spatial_transforms.randomflip_leftright), + ] + shutil.rmtree(temp_dir) # TODO: need to implement this soon. 
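# A minimal sketch of the new augmentation call pattern (illustrative only;
# the file pattern and parameter values here are hypothetical, and the
# transform modules are assumed to import as in dataset_test.py). Each entry
# in `augment_steps` is a (callable, kwargs) pair, replacing the old
# `get_dataset(augment=...)` argument.
from nobrainer import intensity_transforms, spatial_transforms
from nobrainer.dataset import Dataset

dset = Dataset.from_tfrecords(
    file_pattern="data/data-train_shard-*.tfrec",  # hypothetical shard pattern
    n_volumes=10,
    volume_shape=(256, 256, 256),
    n_classes=1,
)
# Apply each transform with probability 0.5 per example, as in Dataset.augment.
dset = dset.augment(
    [
        (intensity_transforms.addGaussianNoise, {"noise_mean": 0.1, "noise_std": 0.5}),
        (spatial_transforms.randomflip_leftright, {}),
    ]
)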
@@ -136,31 +137,49 @@ def test_get_dataset(): def test_get_steps_per_epoch(): - nsteps = dataset.get_steps_per_epoch( + volume_shape = (256, 256, 256) + temp_dir = tempfile.mkdtemp() + nifti_paths = create_dummy_niftis(volume_shape, 10, temp_dir) + filepaths = [(x, i) for i, x in enumerate(nifti_paths)] + file_pattern = write_tfrecs(filepaths, temp_dir, examples_per_shard=1) + dset = dataset.Dataset.from_tfrecords( + file_pattern=file_pattern.replace("*", "000"), n_volumes=1, - volume_shape=(256, 256, 256), + volume_shape=volume_shape, block_shape=(64, 64, 64), - batch_size=1, + scalar_labels=True, + n_classes=1, ) - assert nsteps == 64 - nsteps = dataset.get_steps_per_epoch( + assert dset.get_steps_per_epoch() == 64 + + dset = dataset.Dataset.from_tfrecords( + file_pattern=file_pattern.replace("*", "000"), n_volumes=1, - volume_shape=(256, 256, 256), + volume_shape=volume_shape, block_shape=(64, 64, 64), - batch_size=64, - ) - assert nsteps == 1 - nsteps = dataset.get_steps_per_epoch( + scalar_labels=True, + n_classes=1, + ).batch(64) + assert dset.get_steps_per_epoch() == 1 + + dset = dataset.Dataset.from_tfrecords( + file_pattern=file_pattern.replace("*", "000"), n_volumes=1, - volume_shape=(256, 256, 256), + volume_shape=volume_shape, block_shape=(64, 64, 64), - batch_size=63, - ) - assert nsteps == 2 - nsteps = dataset.get_steps_per_epoch( + scalar_labels=True, + n_classes=1, + ).batch(63) + assert dset.get_steps_per_epoch() == 2 + + dset = dataset.Dataset.from_tfrecords( + file_pattern=file_pattern, n_volumes=10, - volume_shape=(256, 256, 256), + volume_shape=volume_shape, block_shape=(128, 128, 128), - batch_size=4, - ) - assert nsteps == 20 + scalar_labels=True, + n_classes=1, + ).batch(4) + assert dset.get_steps_per_epoch() == 20 + + shutil.rmtree(temp_dir) diff --git a/nobrainer/tests/tfrecord_test.py b/nobrainer/tests/tfrecord_test.py index 55ff205a..2cbb9b41 100644 --- a/nobrainer/tests/tfrecord_test.py +++ b/nobrainer/tests/tfrecord_test.py @@ -27,7 +27,7 @@ def test_write_read_volume_labels(csv_of_volumes, tmp_path): # noqa: F811 dset = tf.data.TFRecordDataset(list(map(str, paths)), compression_type="GZIP") dset = dset.map( - tfrecord.parse_example_fn(volume_shape=(8, 8, 8), scalar_label=False) + tfrecord.parse_example_fn(volume_shape=(8, 8, 8), scalar_labels=False) ) for ref, test in zip(files, dset): @@ -60,7 +60,7 @@ def test_write_read_volume_labels_all_processes(csv_of_volumes, tmp_path): # no dset = tf.data.TFRecordDataset(list(map(str, paths)), compression_type="GZIP") dset = dset.map( - tfrecord.parse_example_fn(volume_shape=(8, 8, 8), scalar_label=False) + tfrecord.parse_example_fn(volume_shape=(8, 8, 8), scalar_labels=False) ) for ref, test in zip(files, dset): @@ -94,7 +94,7 @@ def test_write_read_float_labels(csv_of_volumes, tmp_path): # noqa: F811 dset = tf.data.TFRecordDataset(list(map(str, paths)), compression_type="GZIP") dset = dset.map( - tfrecord.parse_example_fn(volume_shape=(8, 8, 8), scalar_label=True) + tfrecord.parse_example_fn(volume_shape=(8, 8, 8), scalar_labels=True) ) for ref, test in zip(files, dset): @@ -123,7 +123,7 @@ def test_write_read_int_labels(csv_of_volumes, tmp_path): # noqa: F811 dset = tf.data.TFRecordDataset(list(map(str, paths)), compression_type="GZIP") dset = dset.map( - tfrecord.parse_example_fn(volume_shape=(8, 8, 8), scalar_label=True) + tfrecord.parse_example_fn(volume_shape=(8, 8, 8), scalar_labels=True) ) for ref, test in zip(files, dset): diff --git a/nobrainer/tfrecord.py b/nobrainer/tfrecord.py index 39ec8c97..405ba87d 
100644 --- a/nobrainer/tfrecord.py +++ b/nobrainer/tfrecord.py @@ -105,14 +105,14 @@ def write( get_reusable_executor().shutdown(wait=True) -def parse_example_fn(volume_shape, scalar_label=False): +def parse_example_fn(volume_shape, scalar_labels=False): """Return function that can be used to read TFRecord file into tensors. Parameters ---------- - volume_shape: sequence, the shape of the feature data. If `scalar_label` is `False`, + volume_shape: sequence, the shape of the feature data. If `scalar_labels` is `False`, this also corresponds to the shape of the label data. - scalar_label: boolean, if `True`, label is a scalar. If `False`, label must be the + scalar_labels: boolean, if `True`, label is a scalar. If `False`, label must be the same shape as feature data. Returns @@ -130,7 +130,7 @@ def parse_example(serialized): Returns ------- - Tuple of two tensors. If `scalar_label` is `False`, both tensors have shape + Tuple of two tensors. If `scalar_labels` is `False`, both tensors have shape `volume_shape`. Otherwise, the first tensor has shape `volume_shape`, and the second is a scalar tensor. """ @@ -148,7 +148,7 @@ def parse_example(serialized): # xshape = tf.cast( # tf.io.decode_raw(e["feature/shape"], _TFRECORDS_DTYPE), tf.int32) x = tf.reshape(x, shape=volume_shape) - if not scalar_label: + if not scalar_labels: y = tf.reshape(y, shape=volume_shape) else: y = tf.reshape(y, shape=[1]) @@ -285,7 +285,7 @@ def __init__( # files, though it is possible to have existing filenames that # are integers or floats. labels = [y for _, y in features_labels] - self.scalar_label = _labels_all_scalar(labels) + self.scalar_labels = _labels_all_scalar(labels) self._j = 0 no_exist = [] @@ -295,7 +295,7 @@ def __init__( if no_exist: raise ValueError("Some files do not exist: {}".format(", ".join(no_exist))) - if not self.scalar_label: + if not self.scalar_labels: no_exist = [] for _, y in self.features_labels: if not Path(y).exists(): @@ -325,7 +325,7 @@ def _serialize(self, index): ) if self.multi_resolution: # only scalar label - if not self.scalar_label: + if not self.scalar_labels: y = 0 proto_dict = {} for resolution in self.resolutions[::-1]: @@ -346,7 +346,7 @@ def _serialize(self, index): return proto_dict else: label_affine = None - if not self.scalar_label: + if not self.scalar_labels: y, label_affine = read_volume( y, return_affine=True, dtype=_TFRECORDS_DTYPE, to_ras=self.to_ras )
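A minimal end-to-end sketch of the refactored API, for orientation only: `filepaths` is a hypothetical list of `(features, labels)` NIfTI path pairs, and the estimator is assumed to be constructed from a model-building function such as `meshnet`, as the imports in checkpoint_test.py suggest.

from nobrainer.dataset import Dataset
from nobrainer.models import meshnet
from nobrainer.processing.segmentation import Segmentation

# Write TFRecord shards and return wrapped (train, eval) Dataset objects.
ds_train, ds_eval = Dataset.from_files(
    filepaths,                    # hypothetical [(features_path, labels_path), ...]
    out_tfrec_dir="data",         # hypothetical output directory
    shard_size=3,
    eval_size=0.1,
    n_classes=1,
    block_shape=(64, 64, 64),
)

# Preprocessing is now chained on the wrapper instead of passed to get_dataset.
ds_train = ds_train.shuffle(10).repeat(None)

# fit reads batch_size, block_shape, scalar_labels, and n_classes from the
# wrapper and computes steps per epoch via Dataset.get_steps_per_epoch.
model = Segmentation(meshnet)
model.fit(dataset_train=ds_train, dataset_validate=ds_eval, epochs=2)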