Skip to content

Commit

Permalink
Changes required to support the warmstart guide notebook (#266)
Browse files Browse the repository at this point in the history
  • Loading branch information
ohinds authored Sep 19, 2023
1 parent 7de8638 commit 6b242e5
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 11 deletions.
35 changes: 28 additions & 7 deletions nobrainer/processing/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class BaseEstimator:
state_variables = []
model_ = None

def __init__(self, checkpoint_filepath=None, multi_gpu=False):
def __init__(self, checkpoint_filepath=None, multi_gpu=True):
self.checkpoint_tracker = None
if checkpoint_filepath:
from .checkpoint import CheckpointTracker
Expand Down Expand Up @@ -54,7 +54,13 @@ def save(self, save_dir):
pk.dump(model_info, fp)

@classmethod
def load(cls, model_dir, multi_gpu=False, custom_objects=None, compile=False):
def load(
cls,
model_dir,
multi_gpu=True,
custom_objects=None,
compile=False,
):
"""Loads a trained model from a save directory"""
model_dir = Path(str(model_dir).rstrip(os.pathsep))
assert model_dir.exists() and model_dir.is_dir()
Expand All @@ -64,20 +70,31 @@ def load(cls, model_dir, multi_gpu=False, custom_objects=None, compile=False):
if model_info["classname"] != cls.__name__:
raise ValueError(f"Model class does not match {cls.__name__}")
del model_info["classname"]

klass = cls(**model_info["__init__"])
del model_info["__init__"]
for key, value in model_info.items():
setattr(klass, key, value)
klass.strategy = get_strategy(multi_gpu)

klass.strategy = get_strategy(multi_gpu)
with klass.strategy.scope():
klass.model_ = tf.keras.models.load_model(
model_dir, custom_objects=custom_objects, compile=compile
model_dir,
custom_objects=custom_objects,
compile=compile,
)
return klass

@classmethod
def init_with_checkpoints(cls, model_name, checkpoint_filepath):
def init_with_checkpoints(
cls,
model_name,
checkpoint_filepath,
multi_gpu=True,
custom_objects=None,
compile=False,
model_args=None,
):
"""Initialize a model for training, either from the latest
checkpoint found, or from scratch if no checkpoints are
found. This is useful for long-running model fits that may be
Expand All @@ -96,9 +113,13 @@ def init_with_checkpoints(cls, model_name, checkpoint_filepath):
from .checkpoint import CheckpointTracker

checkpoint_tracker = CheckpointTracker(cls, checkpoint_filepath)
estimator = checkpoint_tracker.load()
estimator = checkpoint_tracker.load(
multi_gpu=multi_gpu,
custom_objects=custom_objects,
compile=compile,
)
if not estimator:
estimator = cls(model_name)
estimator = cls(model_name, model_args=model_args)
estimator.checkpoint_tracker = checkpoint_tracker
checkpoint_tracker.estimator = estimator
return estimator
Expand Down
15 changes: 13 additions & 2 deletions nobrainer/processing/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ def save(self, directory):
logging.info(f"Saving to dir {directory}")
self.estimator.save(directory)

def load(self):
def load(
self,
multi_gpu=True,
custom_objects=None,
compile=False,
model_args=None,
):
"""Loads the most-recently created checkpoint from the
checkpoint directory.
"""
Expand All @@ -44,6 +50,11 @@ def load(self):
return None

latest = max(checkpoints, key=os.path.getctime)
self.estimator = self.estimator.load(latest)
self.estimator = self.estimator.load(
latest,
multi_gpu=multi_gpu,
custom_objects=custom_objects,
compile=compile,
)
logging.info(f"Loaded estimator from {latest}.")
return self.estimator
2 changes: 1 addition & 1 deletion nobrainer/processing/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(
dimensionality=3,
g_fmap_base=1024,
d_fmap_base=1024,
multi_gpu=False,
multi_gpu=True,
):
super().__init__(multi_gpu=multi_gpu)
self.model_ = None
Expand Down
2 changes: 1 addition & 1 deletion nobrainer/processing/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class Segmentation(BaseEstimator):
state_variables = ["block_shape_", "volume_shape_", "scalar_labels_"]

def __init__(
self, base_model, model_args=None, checkpoint_filepath=None, multi_gpu=False
self, base_model, model_args=None, checkpoint_filepath=None, multi_gpu=True
):
super().__init__(checkpoint_filepath=checkpoint_filepath, multi_gpu=multi_gpu)

Expand Down

0 comments on commit 6b242e5

Please sign in to comment.