diff --git a/nobrainer/dataset.py b/nobrainer/dataset.py
index df8136ac..acb5f97e 100644
--- a/nobrainer/dataset.py
+++ b/nobrainer/dataset.py
@@ -78,6 +78,7 @@ def from_tfrecords(
     block_shape=None,
     scalar_labels=False,
     n_classes=1,
+    options=None,
     num_parallel_calls=1,
 ):
     """Function to retrieve a saved tf record as a nobrainer Dataset
@@ -97,6 +98,9 @@ def from_tfrecords(
     compressed = _is_gzipped(files[0], filesys=fs)
     dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)
 
+    if options:
+        dataset = dataset.with_options(options)
+
     # Read each of these files as a TFRecordDataset.
     # Assume all files have same compression type as the first file.
     compression_type = "GZIP" if compressed else None
@@ -150,6 +154,7 @@ def from_files(
     eval_size=0.1,
     n_classes=1,
     block_shape=None,
+    options=None,
 ):
     """Create Nobrainer datasets from data
     filepaths: List(str), list of paths to individual input data files.
@@ -211,6 +216,7 @@ def from_files(
         scalar_labels=scalar_labels,
         n_classes=n_classes,
         block_shape=block_shape,
+        options=options,
         num_parallel_calls=num_parallel_calls,
     )
     ds_eval = None
@@ -223,6 +229,7 @@ def from_files(
         scalar_labels=scalar_labels,
         n_classes=n_classes,
         block_shape=block_shape,
+        options=options,
         num_parallel_calls=num_parallel_calls,
     )
     return ds_train, ds_eval
diff --git a/nobrainer/processing/checkpoint.py b/nobrainer/processing/checkpoint.py
index e9cb9585..bf54ef43 100644
--- a/nobrainer/processing/checkpoint.py
+++ b/nobrainer/processing/checkpoint.py
@@ -47,8 +47,13 @@ def load(
         """
         checkpoints = glob(os.path.join(os.path.dirname(self.filepath), "*/"))
         if not checkpoints:
+            self.last_epoch = 0
             return None
 
+        # TODO, we should probably exclude non-checkpoint files here,
+        # and maybe parse the filename for the epoch number
+        self.last_epoch = len(checkpoints)
+
         latest = max(checkpoints, key=os.path.getctime)
         self.estimator = self.estimator.load(
             latest,
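
The dataset changes thread a new `options=None` keyword through `from_files` into `from_tfrecords`, where a supplied `tf.data.Options` object is applied to the file-listing dataset via `Dataset.with_options`. A minimal usage sketch follows; the file pattern and `block_shape` value are hypothetical placeholders, and only parameters visible in the hunks above are used:

```python
import tensorflow as tf

from nobrainer.dataset import from_tfrecords

# Shard the input pipeline by data rather than by file, a common
# setting when iterating a dataset under a distribution strategy.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA
)

dataset = from_tfrecords(
    "data/data-train_shard-*.tfrec",  # hypothetical file pattern
    block_shape=(64, 64, 64),         # parameter shown in the hunk above
    options=options,                  # new keyword added in this diff
)
```

Because `with_options` is applied to the initial listing dataset, the options propagate through the downstream transformations built on top of it.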
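The checkpoint change makes `load` record a `last_epoch` attribute: 0 when no checkpoint directories exist, otherwise the number of checkpoint directories found (an approximation, per the TODO, until epoch numbers are parsed from filenames). One plausible consumer is sketched below, under the assumption that `tracker` is an object with the patched `load` method; the helper itself is illustrative and not part of this diff:

```python
def initial_epoch_for_resume(tracker):
    """Choose the epoch to resume training from.

    Assumes tracker.load() behaves as patched above: it returns None
    and sets last_epoch = 0 when no checkpoints exist, otherwise loads
    the newest checkpoint and sets last_epoch to the checkpoint count.
    """
    estimator = tracker.load()
    if estimator is None:
        return 0  # no checkpoints; start training from scratch
    return tracker.last_epoch


# A trainer could then keep epoch numbering continuous across restarts,
# e.g. model.fit(train_ds, epochs=total_epochs,
#                initial_epoch=initial_epoch_for_resume(tracker))
```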