diff --git a/nobrainer/processing/base.py b/nobrainer/processing/base.py index 7acef848..f4df69fe 100644 --- a/nobrainer/processing/base.py +++ b/nobrainer/processing/base.py @@ -20,7 +20,7 @@ class BaseEstimator: state_variables = [] model_ = None - def __init__(self, checkpoint_filepath=None, multi_gpu=False): + def __init__(self, checkpoint_filepath=None, multi_gpu=True): self.checkpoint_tracker = None if checkpoint_filepath: from .checkpoint import CheckpointTracker @@ -54,7 +54,13 @@ def save(self, save_dir): pk.dump(model_info, fp) @classmethod - def load(cls, model_dir, multi_gpu=False, custom_objects=None, compile=False): + def load( + cls, + model_dir, + multi_gpu=True, + custom_objects=None, + compile=False, + ): """Loads a trained model from a save directory""" model_dir = Path(str(model_dir).rstrip(os.pathsep)) assert model_dir.exists() and model_dir.is_dir() @@ -64,20 +70,31 @@ def load(cls, model_dir, multi_gpu=False, custom_objects=None, compile=False): if model_info["classname"] != cls.__name__: raise ValueError(f"Model class does not match {cls.__name__}") del model_info["classname"] + klass = cls(**model_info["__init__"]) del model_info["__init__"] for key, value in model_info.items(): setattr(klass, key, value) - klass.strategy = get_strategy(multi_gpu) + klass.strategy = get_strategy(multi_gpu) with klass.strategy.scope(): klass.model_ = tf.keras.models.load_model( - model_dir, custom_objects=custom_objects, compile=compile + model_dir, + custom_objects=custom_objects, + compile=compile, ) return klass @classmethod - def init_with_checkpoints(cls, model_name, checkpoint_filepath): + def init_with_checkpoints( + cls, + model_name, + checkpoint_filepath, + multi_gpu=True, + custom_objects=None, + compile=False, + model_args=None, + ): """Initialize a model for training, either from the latest checkpoint found, or from scratch if no checkpoints are found. This is useful for long-running model fits that may be @@ -96,9 +113,13 @@ def init_with_checkpoints(cls, model_name, checkpoint_filepath): from .checkpoint import CheckpointTracker checkpoint_tracker = CheckpointTracker(cls, checkpoint_filepath) - estimator = checkpoint_tracker.load() + estimator = checkpoint_tracker.load( + multi_gpu=multi_gpu, + custom_objects=custom_objects, + compile=compile, + ) if not estimator: - estimator = cls(model_name) + estimator = cls(model_name, model_args=model_args) estimator.checkpoint_tracker = checkpoint_tracker checkpoint_tracker.estimator = estimator return estimator diff --git a/nobrainer/processing/checkpoint.py b/nobrainer/processing/checkpoint.py index 7fb3d884..e9cb9585 100644 --- a/nobrainer/processing/checkpoint.py +++ b/nobrainer/processing/checkpoint.py @@ -35,7 +35,13 @@ def save(self, directory): logging.info(f"Saving to dir {directory}") self.estimator.save(directory) - def load(self): + def load( + self, + multi_gpu=True, + custom_objects=None, + compile=False, + model_args=None, + ): """Loads the most-recently created checkpoint from the checkpoint directory. """ @@ -44,6 +50,11 @@ def load(self): return None latest = max(checkpoints, key=os.path.getctime) - self.estimator = self.estimator.load(latest) + self.estimator = self.estimator.load( + latest, + multi_gpu=multi_gpu, + custom_objects=custom_objects, + compile=compile, + ) logging.info(f"Loaded estimator from {latest}.") return self.estimator diff --git a/nobrainer/processing/generation.py b/nobrainer/processing/generation.py index c7b42cc5..bc0da104 100644 --- a/nobrainer/processing/generation.py +++ b/nobrainer/processing/generation.py @@ -21,7 +21,7 @@ def __init__( dimensionality=3, g_fmap_base=1024, d_fmap_base=1024, - multi_gpu=False, + multi_gpu=True, ): super().__init__(multi_gpu=multi_gpu) self.model_ = None diff --git a/nobrainer/processing/segmentation.py b/nobrainer/processing/segmentation.py index 480f146c..66e1768d 100644 --- a/nobrainer/processing/segmentation.py +++ b/nobrainer/processing/segmentation.py @@ -15,7 +15,7 @@ class Segmentation(BaseEstimator): state_variables = ["block_shape_", "volume_shape_", "scalar_labels_"] def __init__( - self, base_model, model_args=None, checkpoint_filepath=None, multi_gpu=False + self, base_model, model_args=None, checkpoint_filepath=None, multi_gpu=True ): super().__init__(checkpoint_filepath=checkpoint_filepath, multi_gpu=multi_gpu)