From 29991dd938d4ddc85c30625119372fd42262a37a Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 22 Aug 2023 18:46:12 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 nobrainer/processing/base.py         |  2 ++
 nobrainer/processing/checkpoint.py   |  3 ++-
 nobrainer/processing/segmentation.py |  5 +++--
 nobrainer/tests/checkpoint_test.py   | 14 ++++++++------
 4 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/nobrainer/processing/base.py b/nobrainer/processing/base.py
index 5f22be7c..5d2f80b2 100644
--- a/nobrainer/processing/base.py
+++ b/nobrainer/processing/base.py
@@ -24,6 +24,7 @@ def __init__(self, checkpoint_filepath=None, multi_gpu=False):
         self.checkpoint_tracker = None
         if checkpoint_filepath:
             from .checkpoint import CheckpointTracker
+
             self.checkpoint_tracker = CheckpointTracker(self, checkpoint_filepath)
 
         self.strategy = get_strategy(multi_gpu)
@@ -78,6 +79,7 @@ def load(cls, model_dir, multi_gpu=False, custom_objects=None, compile=False):
     @classmethod
     def load_latest(cls, checkpoint_filepath):
         from .checkpoint import CheckpointTracker
+
         checkpoint_tracker = CheckpointTracker(cls, checkpoint_filepath)
         estimator = checkpoint_tracker.load()
         estimator.checkpoint_tracker = checkpoint_tracker
diff --git a/nobrainer/processing/checkpoint.py b/nobrainer/processing/checkpoint.py
index 1787922b..3b39c75e 100644
--- a/nobrainer/processing/checkpoint.py
+++ b/nobrainer/processing/checkpoint.py
@@ -3,6 +3,7 @@
 from glob import glob
 import logging
 import os
+
 import tensorflow as tf
 
 from .base import BaseEstimator
@@ -40,7 +41,7 @@ def load(self):
         """Loads the most-recently created checkpoint from the checkpoint
         directory.
         """
-        checkpoints = glob(os.path.join(os.path.dirname(self.filepath), '*/'))
+        checkpoints = glob(os.path.join(os.path.dirname(self.filepath), "*/"))
         latest = max(checkpoints, key=os.path.getctime)
         self.estimator = self.estimator.load(latest)
         logging.info(f"Loaded estimator from {latest}.")
diff --git a/nobrainer/processing/segmentation.py b/nobrainer/processing/segmentation.py
index 8a4c3ea1..e47d730e 100644
--- a/nobrainer/processing/segmentation.py
+++ b/nobrainer/processing/segmentation.py
@@ -8,7 +8,6 @@
 from .. import losses, metrics
 from ..dataset import get_steps_per_epoch
 
-
 logging.getLogger().setLevel(logging.INFO)
 
 
@@ -17,7 +16,9 @@ class Segmentation(BaseEstimator):
 
     state_variables = ["block_shape_", "volume_shape_", "scalar_labels_"]
 
-    def __init__(self, base_model, model_args=None, checkpoint_filepath=None, multi_gpu=False):
+    def __init__(
+        self, base_model, model_args=None, checkpoint_filepath=None, multi_gpu=False
+    ):
         super().__init__(checkpoint_filepath=checkpoint_filepath, multi_gpu=multi_gpu)
 
         if not isinstance(base_model, str):
diff --git a/nobrainer/tests/checkpoint_test.py b/nobrainer/tests/checkpoint_test.py
index 197e55ca..785c49a0 100644
--- a/nobrainer/tests/checkpoint_test.py
+++ b/nobrainer/tests/checkpoint_test.py
@@ -1,13 +1,15 @@
 """Tests for `nobrainer.processing.checkpoint`."""
 
-from nobrainer.processing.segmentation import Segmentation
-from nobrainer.models import meshnet
+import os
+
 import numpy as np
 from numpy.testing import assert_allclose
-import os
 import pytest
 import tensorflow as tf
 
+from nobrainer.models import meshnet
+from nobrainer.processing.segmentation import Segmentation
+
 
 def _assert_model_weights_allclose(model1, model2):
     for layer1, layer2 in zip(model1.model.layers, model2.model.layers):
@@ -17,17 +19,17 @@ def _assert_model_weights_allclose(model1, model2):
         for index in range(len(weights1)):
             assert_allclose(weights1[index], weights2[index], rtol=1e-06, atol=1e-08)
 
+
 def test_checkpoint(tmp_path):
     data_shape = (8, 8, 8, 8, 1)
     train = tf.data.Dataset.from_tensors(
-        (np.random.rand(*data_shape),
-         np.random.randint(0, 1, data_shape))
+        (np.random.rand(*data_shape), np.random.randint(0, 1, data_shape))
     )
     train.scalar_labels = False
     train.n_volumes = data_shape[0]
     train.volume_shape = data_shape[1:4]
-    checkpoint_filepath = os.path.join(tmp_path, 'checkpoint-epoch_{epoch:03d}')
+    checkpoint_filepath = os.path.join(tmp_path, "checkpoint-epoch_{epoch:03d}")
     model1 = Segmentation(meshnet, checkpoint_filepath=checkpoint_filepath)
     model1.fit(
         dataset_train=train,