Small changes to support long, preemptable training runs
ohinds committed Oct 2, 2023
1 parent 6dc2b4c commit f41cf96
Showing 2 changed files with 12 additions and 0 deletions.
nobrainer/dataset.py: 7 additions & 0 deletions

@@ -78,6 +78,7 @@ def from_tfrecords(
         block_shape=None,
         scalar_labels=False,
         n_classes=1,
+        options=None,
         num_parallel_calls=1,
     ):
         """Function to retrieve a saved tf record as a nobrainer Dataset
@@ -97,6 +98,9 @@
         compressed = _is_gzipped(files[0], filesys=fs)
         dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)

+        if options:
+            dataset = dataset.with_options(options)
+
         # Read each of these files as a TFRecordDataset.
         # Assume all files have same compression type as the first file.
         compression_type = "GZIP" if compressed else None
@@ -150,6 +154,7 @@ def from_files(
         eval_size=0.1,
         n_classes=1,
         block_shape=None,
+        options=None,
     ):
         """Create Nobrainer datasets from data
         filepaths: List(str), list of paths to individual input data files.
@@ -211,6 +216,7 @@
             scalar_labels=scalar_labels,
             n_classes=n_classes,
             block_shape=block_shape,
+            options=options,
             num_parallel_calls=num_parallel_calls,
         )
         ds_eval = None
@@ -223,6 +229,7 @@
                 scalar_labels=scalar_labels,
                 n_classes=n_classes,
                 block_shape=block_shape,
+                options=options,
                 num_parallel_calls=num_parallel_calls,
             )
         return ds_train, ds_eval
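The new options argument lets callers attach a tf.data.Options object to the record-reading pipeline via dataset.with_options, which is how settings such as the auto-shard policy are typically configured for long, multi-worker, preemptable runs. A minimal usage sketch follows; it is not part of the commit. It assumes from_tfrecords is exposed as a classmethod on nobrainer.dataset.Dataset, and the file pattern, volume_shape parameter, and shapes shown are hypothetical; only the parameters visible in this diff are confirmed.

# Minimal usage sketch, assuming from_tfrecords is a classmethod of
# nobrainer.dataset.Dataset; "volume_shape" and the positional file pattern
# are assumed names/values not shown in this diff.
import tensorflow as tf
from nobrainer.dataset import Dataset

opts = tf.data.Options()
# Shard by element rather than by file, so restarted workers keep making
# progress even when the number of TFRecord files is small.
opts.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA
)

ds_train = Dataset.from_tfrecords(
    "data/train-*.tfrec",            # hypothetical file pattern
    volume_shape=(256, 256, 256),    # assumed parameter name and value
    n_classes=1,
    block_shape=(64, 64, 64),
    options=opts,                    # new keyword added in this commit
    num_parallel_calls=1,
)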
nobrainer/processing/checkpoint.py: 5 additions & 0 deletions

@@ -47,8 +47,13 @@ def load(
         """
         checkpoints = glob(os.path.join(os.path.dirname(self.filepath), "*/"))
         if not checkpoints:
+            self.last_epoch = 0
             return None

+        # TODO, we should probably exclude non-checkpoint files here,
+        # and maybe parse the filename for the epoch number
+        self.last_epoch = len(checkpoints)
+
         latest = max(checkpoints, key=os.path.getctime)
         self.estimator = self.estimator.load(
             latest,
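The last_epoch bookkeeping gives calling code a way to resume a preempted run where it left off. Below is an illustrative sketch of how that value might be consumed; it is not part of the commit, and it assumes tracker is an instance of the checkpoint class patched above, with the model, dataset, and epoch count supplied by the caller.

# Illustrative resume logic (not from this commit); "tracker" is assumed to
# be an instance of the checkpoint class patched above.
import tensorflow as tf


def resume_training(tracker, model: tf.keras.Model,
                    train_ds: tf.data.Dataset, total_epochs: int):
    tracker.load()                    # sets tracker.last_epoch as in the diff above
    start_epoch = tracker.last_epoch  # 0 when no checkpoint directory exists yet
    model.fit(
        train_ds,
        epochs=total_epochs,
        initial_epoch=start_epoch,    # Keras resumes epoch numbering from here
    )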
