From 9286eb71fe4c1683d2c052aa0d25bfac517dcaaa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 24 Aug 2023 18:04:44 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- nobrainer/tests/checkpoint_test.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/nobrainer/tests/checkpoint_test.py b/nobrainer/tests/checkpoint_test.py index 379dc571..593283f0 100644 --- a/nobrainer/tests/checkpoint_test.py +++ b/nobrainer/tests/checkpoint_test.py @@ -14,14 +14,14 @@ def _get_toy_dataset(): data_shape = (8, 8, 8, 8, 1) train = tf.data.Dataset.from_tensors( - (np.random.rand(*data_shape), - np.random.randint(0, 1, data_shape)) + (np.random.rand(*data_shape), np.random.randint(0, 1, data_shape)) ) train.scalar_labels = False train.n_volumes = data_shape[0] train.volume_shape = data_shape[1:4] return train + def _assert_model_weights_allclose(model1, model2): for layer1, layer2 in zip(model1.model.layers, model2.model.layers): weights1 = layer1.get_weights() @@ -51,11 +51,12 @@ def test_checkpoint(tmp_path): model3 = Segmentation.load_latest(checkpoint_filepath=checkpoint_filepath) _assert_model_weights_allclose(model2, model3) + def test_warm_start_workflow(tmp_path): train = _get_toy_dataset() - checkpoint_dir = os.path.join('checkpoints') - checkpoint_filepath = os.path.join(checkpoint_dir, '{epoch:03d}') + checkpoint_dir = os.path.join("checkpoints") + checkpoint_filepath = os.path.join(checkpoint_dir, "{epoch:03d}") if not os.path.exists(checkpoint_dir): os.mkdir(checkpoint_dir)