Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 24, 2023
1 parent 2f51db2 commit 9286eb7
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions nobrainer/tests/checkpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
def _get_toy_dataset():
data_shape = (8, 8, 8, 8, 1)
train = tf.data.Dataset.from_tensors(
(np.random.rand(*data_shape),
np.random.randint(0, 1, data_shape))
(np.random.rand(*data_shape), np.random.randint(0, 1, data_shape))
)
train.scalar_labels = False
train.n_volumes = data_shape[0]
train.volume_shape = data_shape[1:4]
return train


def _assert_model_weights_allclose(model1, model2):
for layer1, layer2 in zip(model1.model.layers, model2.model.layers):
weights1 = layer1.get_weights()
Expand Down Expand Up @@ -51,11 +51,12 @@ def test_checkpoint(tmp_path):
model3 = Segmentation.load_latest(checkpoint_filepath=checkpoint_filepath)
_assert_model_weights_allclose(model2, model3)


def test_warm_start_workflow(tmp_path):
train = _get_toy_dataset()

checkpoint_dir = os.path.join('checkpoints')
checkpoint_filepath = os.path.join(checkpoint_dir, '{epoch:03d}')
checkpoint_dir = os.path.join("checkpoints")
checkpoint_filepath = os.path.join(checkpoint_dir, "{epoch:03d}")
if not os.path.exists(checkpoint_dir):
os.mkdir(checkpoint_dir)

Expand Down

0 comments on commit 9286eb7

Please sign in to comment.