
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Aug 22, 2023
1 parent c7d7cfe commit 29991dd
Showing 4 changed files with 15 additions and 9 deletions.
2 changes: 2 additions & 0 deletions nobrainer/processing/base.py
@@ -24,6 +24,7 @@ def __init__(self, checkpoint_filepath=None, multi_gpu=False):
         self.checkpoint_tracker = None
         if checkpoint_filepath:
             from .checkpoint import CheckpointTracker
+
             self.checkpoint_tracker = CheckpointTracker(self, checkpoint_filepath)
 
         self.strategy = get_strategy(multi_gpu)
@@ -78,6 +79,7 @@ def load(cls, model_dir, multi_gpu=False, custom_objects=None, compile=False):
     @classmethod
     def load_latest(cls, checkpoint_filepath):
         from .checkpoint import CheckpointTracker
+
         checkpoint_tracker = CheckpointTracker(cls, checkpoint_filepath)
         estimator = checkpoint_tracker.load()
         estimator.checkpoint_tracker = checkpoint_tracker
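For orientation, here is a minimal usage sketch of the load_latest classmethod touched above. It is not part of this commit: the directory name is made up, the fit arguments are elided, and the checkpoint path pattern is borrowed from checkpoint_test.py further down.

import os

from nobrainer.models import meshnet
from nobrainer.processing.segmentation import Segmentation

# Hypothetical location; each saved checkpoint becomes a subdirectory of "checkpoints/".
checkpoint_filepath = os.path.join("checkpoints", "checkpoint-epoch_{epoch:03d}")

# Training run: the estimator wires up a CheckpointTracker internally.
model = Segmentation(meshnet, checkpoint_filepath=checkpoint_filepath)
# model.fit(dataset_train=..., epochs=...)  # arguments elided in this sketch

# Later: rebuild the estimator from the most recently created checkpoint.
# Assumes at least one checkpoint directory already exists on disk.
resumed = Segmentation.load_latest(checkpoint_filepath)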
3 changes: 2 additions & 1 deletion nobrainer/processing/checkpoint.py
@@ -3,6 +3,7 @@
 from glob import glob
 import logging
 import os
+
 import tensorflow as tf
 
 from .base import BaseEstimator
@@ -40,7 +41,7 @@ def load(self):
         """Loads the most-recently created checkpoint from the
         checkpoint directory.
         """
-        checkpoints = glob(os.path.join(os.path.dirname(self.filepath), '*/'))
+        checkpoints = glob(os.path.join(os.path.dirname(self.filepath), "*/"))
         latest = max(checkpoints, key=os.path.getctime)
         self.estimator = self.estimator.load(latest)
         logging.info(f"Loaded estimator from {latest}.")
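The load method above finds the newest checkpoint by globbing the directories next to the checkpoint filepath and taking the one with the latest creation time. The standalone snippet below illustrates that same glob/getctime pattern with a made-up path; it is an illustration, not code from this commit.

import os
from glob import glob

# Hypothetical template; real callers pass the estimator's checkpoint filepath.
filepath = "checkpoints/checkpoint-epoch_{epoch:03d}"

# Checkpoints are saved as directories alongside the template (see load() above).
checkpoints = glob(os.path.join(os.path.dirname(filepath), "*/"))
if checkpoints:
    # Pick the most recently created directory, mirroring CheckpointTracker.load.
    latest = max(checkpoints, key=os.path.getctime)
    print(f"Would load estimator from {latest}")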
5 changes: 3 additions & 2 deletions nobrainer/processing/segmentation.py
@@ -8,7 +8,6 @@
 from .. import losses, metrics
 from ..dataset import get_steps_per_epoch
 
-
 logging.getLogger().setLevel(logging.INFO)
 
 
@@ -17,7 +16,9 @@ class Segmentation(BaseEstimator):
 
     state_variables = ["block_shape_", "volume_shape_", "scalar_labels_"]
 
-    def __init__(self, base_model, model_args=None, checkpoint_filepath=None, multi_gpu=False):
+    def __init__(
+        self, base_model, model_args=None, checkpoint_filepath=None, multi_gpu=False
+    ):
         super().__init__(checkpoint_filepath=checkpoint_filepath, multi_gpu=multi_gpu)
 
         if not isinstance(base_model, str):
14 changes: 8 additions & 6 deletions nobrainer/tests/checkpoint_test.py
@@ -1,13 +1,15 @@
 """Tests for `nobrainer.processing.checkpoint`."""
 
-from nobrainer.processing.segmentation import Segmentation
-from nobrainer.models import meshnet
+import os
+
 import numpy as np
 from numpy.testing import assert_allclose
-import os
 import pytest
 import tensorflow as tf
 
+from nobrainer.models import meshnet
+from nobrainer.processing.segmentation import Segmentation
+
 
 def _assert_model_weights_allclose(model1, model2):
     for layer1, layer2 in zip(model1.model.layers, model2.model.layers):
@@ -17,17 +19,17 @@ def _assert_model_weights_allclose(model1, model2):
         for index in range(len(weights1)):
             assert_allclose(weights1[index], weights2[index], rtol=1e-06, atol=1e-08)
 
+
 def test_checkpoint(tmp_path):
     data_shape = (8, 8, 8, 8, 1)
     train = tf.data.Dataset.from_tensors(
-        (np.random.rand(*data_shape),
-         np.random.randint(0, 1, data_shape))
+        (np.random.rand(*data_shape), np.random.randint(0, 1, data_shape))
     )
     train.scalar_labels = False
     train.n_volumes = data_shape[0]
     train.volume_shape = data_shape[1:4]
 
-    checkpoint_filepath = os.path.join(tmp_path, 'checkpoint-epoch_{epoch:03d}')
+    checkpoint_filepath = os.path.join(tmp_path, "checkpoint-epoch_{epoch:03d}")
     model1 = Segmentation(meshnet, checkpoint_filepath=checkpoint_filepath)
     model1.fit(
         dataset_train=train,
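The hunk ends inside the fit call; the remainder of the test sits behind the fold and is not reproduced here. Purely as a hypothetical sketch, a round-trip check for a test like this could continue along the following lines once fit has written a checkpoint (the names reuse helpers defined earlier in the file; the exact assertions are an assumption, not the file's actual contents).

    # --- hypothetical continuation, not the folded remainder of this test ---
    # After fit() finishes, restore from the newest checkpoint and compare
    # weights against the estimator that wrote it.
    model2 = Segmentation.load_latest(checkpoint_filepath)
    _assert_model_weights_allclose(model1, model2)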
