

Flax checkpoint updates (#472)
* Update train state definition

* Fix cases where no checkpoint should be expected

* Remove TF dependence in apply from checkpoint

* Update docstrings

* Add checkpoint exception test

* Pass state when reading checkpoint

* Update checkpoint tests

* Update apply from checkpointing and corresponding tests

* Remove print from test

* Fix apply from checkpoints and update docstrings

* Update checkpoint filenames in modl example scripts

* Use definition of Traversal from Flax while it is available

* Manage case where no last step is found in checkpointing

* Add checkpointing flag to MoDL examples

* Add checkpointing flag to CT unet and odp examples

* Add checkpointing flag to dncnn example

* Add checkpointing flag to deconv odp example

* Bump maximum flax version

* Resolve TracerBoolConversionError in linop power_iteration (see the jittable power iteration sketch below)

* Update data submodule

* Revert to non-jittable power_iteration linop _util

* Use jittable power iteration in flax inverse models

* Update data module

* Fix mypy scico flax errors

* Update submodule

* Edit example scripts descriptions

* Edit docstring

* Edit description

* Edit docstring

* Add entry to CHANGES.rst

* Update docstring description

* Update submodule

---------

Co-authored-by: Cristina Garcia-Cardona <[email protected]>
Co-authored-by: Brendt Wohlberg <[email protected]>
3 people authored Dec 6, 2023
1 parent 021d904 commit e378d7a
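The two entries above on power iteration ("Resolve TracerBoolConversionError in linop power_iteration" and "Use jittable power iteration in flax inverse models") concern the same JAX constraint: a Python `while` loop whose stopping test depends on a traced value cannot run under `jax.jit`. Replacing the convergence test with a fixed-count `lax.fori_loop` keeps the loop traceable, at the cost of possibly iterating past convergence. A minimal sketch of the jittable variant, written against a generic matvec callable rather than the actual scico.linop interface:

```python
import jax
import jax.numpy as jnp


def power_iteration(A, v0, num_iters=100):
    """Estimate the dominant eigenvalue of the linear operator A.

    The fixed iteration count makes the loop traceable under jax.jit;
    a Python `while` on a traced residual would raise
    TracerBoolConversionError.
    """

    def body(i, v):
        w = A(v)
        return w / jnp.linalg.norm(w)

    v = jax.lax.fori_loop(0, num_iters, body, v0 / jnp.linalg.norm(v0))
    mu = jnp.vdot(v, A(v))  # Rayleigh quotient at the final iterate
    return mu, v


# Usage: dominant eigenvalue of a small symmetric matrix.
M = jnp.array([[2.0, 1.0], [1.0, 3.0]])
mu, v = jax.jit(lambda u: power_iteration(lambda x: M @ x, u))(jnp.ones(2))
```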
Showing 19 changed files with 449 additions and 304 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
@@ -15,6 +15,8 @@ Version 0.0.5 (unreleased)
• Rename ``AbelProjector`` to ``AbelTransform``.
• Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.20.
+• Support ``flax`` versions up to 0.7.5.
+• Use ``orbax`` for checkpointing ``flax`` models.



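The ``orbax`` entry above moves checkpointing of Flax models to orbax-checkpoint. A minimal sketch of the save/restore pattern for a Flax train state as used with flax 0.7.x; the `state` pytree, `workdir` layout, and `max_to_keep` value are illustrative, not the actual scico.flax trainer internals:

```python
import orbax.checkpoint
from flax.training import orbax_utils


def save_checkpoint(state, workdir, step):
    # CheckpointManager writes step-numbered subdirectories and prunes old
    # ones; orbax expects `workdir` to be an absolute path.
    ckptr = orbax.checkpoint.PyTreeCheckpointer()
    options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2, create=True)
    mngr = orbax.checkpoint.CheckpointManager(workdir, ckptr, options)
    save_args = orbax_utils.save_args_from_target(state)
    mngr.save(step, state, save_kwargs={"save_args": save_args})


def restore_checkpoint(state, workdir):
    ckptr = orbax.checkpoint.PyTreeCheckpointer()
    mngr = orbax.checkpoint.CheckpointManager(workdir, ckptr)
    step = mngr.latest_step()  # None when nothing has been written yet
    if step is None:
        # No checkpoint to read: return the freshly initialized state
        # (the "no checkpoint should be expected" case in this commit).
        return state
    return mngr.restore(step, items=state)
```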
33 changes: 20 additions & 13 deletions examples/scripts/ct_astra_modl_train_foam2.py
@@ -119,7 +119,8 @@
"depth": 10,
"num_filters": 64,
"block_depth": 4,
"cg_iter": 3,
"cg_iter_1": 3,
"cg_iter_2": 8,
}
# training configuration
train_conf: sflax.ConfigDict = {
@@ -132,6 +133,7 @@
"warmup_epochs": 0,
"log_every_steps": 40,
"log": True,
"checkpointing": True,
}


@@ -166,10 +168,11 @@
)

stats_object_ini = None
+stats_object = None

checkpoint_files = []
for dirpath, dirnames, filenames in os.walk(workdir2):
-checkpoint_files = [fn for fn in filenames if str.split(fn, "_")[0] == "checkpoint"]
+checkpoint_files = [fn for fn in filenames]

if len(checkpoint_files) > 0:
model = sflax.MoDLNet(
@@ -178,11 +181,14 @@
channels=channels,
num_filters=model_conf["num_filters"],
block_depth=model_conf["block_depth"],
cg_iter=model_conf["cg_iter"],
cg_iter=model_conf["cg_iter_2"],
)

train_conf["workdir"] = workdir2
train_conf["post_lst"] = [lmbdapos]
# Parameters for 2nd stage
train_conf["workdir"] = workdir2
train_conf["opt_type"] = "ADAM"
train_conf["num_epochs"] = 150
# Construct training object
trainer = sflax.BasicFlaxTrainer(
train_conf,
@@ -203,7 +209,7 @@
channels=channels,
num_filters=model_conf["num_filters"],
block_depth=model_conf["block_depth"],
cg_iter=model_conf["cg_iter"],
cg_iter=model_conf["cg_iter_1"],
)
# First stage: initialization training loop.
workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "modl_ct_out")
@@ -230,8 +236,7 @@

# Second stage: depth iterations training loop.
model.depth = model_conf["depth"]
-model.cg_iter = 8
-train_conf["base_learning_rate"] = 1e-2
+model.cg_iter = model_conf["cg_iter_2"]
train_conf["opt_type"] = "ADAM"
train_conf["num_epochs"] = 150
train_conf["workdir"] = workdir2
@@ -265,7 +270,7 @@


"""
-Compare trained model in terms of reconstruction time
+Evaluate trained model in terms of reconstruction time
and data fidelity.
"""
total_epochs = epochs_init + train_conf["num_epochs"]
@@ -281,7 +286,9 @@
f"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}"
)

-# Plot comparison
+"""
+Plot comparison.
+"""
np.random.seed(123)
indx = np.random.randint(0, high=maxn)

@@ -311,10 +318,10 @@


"""
-Plot convergence statistics. Statistics only generated if a training
-cycle was done (i.e. not reading final epoch results from checkpoint).
+Plot convergence statistics. Statistics are generated only if a training
+cycle was done (i.e. if not reading final epoch results from checkpoint).
"""
-if stats_object is not None:
+if stats_object is not None and len(stats_object.iterations) > 0:
hist = stats_object.history(transpose=True)
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
@@ -341,7 +348,7 @@
fig.show()

# Stats for initialization loop
-if stats_object_ini is not None:
+if stats_object_ini is not None and len(stats_object_ini.iterations) > 0:
hist = stats_object_ini.history(transpose=True)
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
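The detection change above (dropping the `checkpoint_*` filename test in favor of "any files present") is consistent with the orbax layout sketched earlier: checkpoints live in step-numbered subdirectories rather than flat `checkpoint_<step>` files, so a filename-prefix test no longer matches. An equivalent, slightly more explicit sketch using only the standard library:

```python
import os


def has_checkpoint(workdir):
    # Any file anywhere under the work directory counts as evidence of a
    # previous checkpointed run; orbax nests files per step.
    for dirpath, dirnames, filenames in os.walk(workdir):
        if filenames:
            return True
    return False
```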
13 changes: 8 additions & 5 deletions examples/scripts/ct_astra_odp_train_foam2.py
@@ -133,6 +133,7 @@
"warmup_epochs": 0,
"log_every_steps": 160,
"log": True,
"checkpointing": True,
}


@@ -208,7 +209,7 @@


"""
-Compare trained model in terms of reconstruction time and data fidelity.
+Evaluate trained model in terms of reconstruction time and data fidelity.
"""
snr_eval = metric.snr(test_ds["label"][:maxn], output)
psnr_eval = metric.psnr(test_ds["label"][:maxn], output)
@@ -221,7 +222,9 @@
f"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}"
)

-# Plot comparison
+"""
+Plot comparison.
+"""
np.random.seed(123)
indx = np.random.randint(0, high=maxn)

@@ -251,10 +254,10 @@


"""
-Plot convergence statistics. Statistics only generated if a training
-cycle was done (i.e. not reading final epoch results from checkpoint).
+Plot convergence statistics. Statistics are generated only if a training
+cycle was done (i.e. if not reading final epoch results from checkpoint).
"""
-if stats_object is not None:
+if stats_object is not None and len(stats_object.iterations) > 0:
hist = stats_object.history(transpose=True)
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
27 changes: 17 additions & 10 deletions examples/scripts/ct_astra_unet_train_foam2.py
@@ -37,8 +37,8 @@
Read data from cache or generate if not available.
"""
N = 256 # phantom size
-train_nimg = 536 # number of training images
-test_nimg = 64 # number of testing images
+train_nimg = 498 # number of training images
+test_nimg = 32 # number of testing images
nimg = train_nimg + test_nimg
n_projection = 45 # CT views

@@ -83,6 +83,7 @@
"warmup_epochs": 0,
"log_every_steps": 1000,
"log": True,
"checkpointing": True,
}


@@ -123,18 +124,24 @@
"""
Evaluate on testing data.
"""
-start_time = time()
+del train_ds["image"]
+del train_ds["label"]
+
fmap = sflax.FlaxMap(model, modvar)
-output = fmap(test_ds["image"])
+del model, modvar
+
+maxn = test_nimg // 2
+start_time = time()
+output = fmap(test_ds["image"][:maxn])
time_eval = time() - start_time
output = jax.numpy.clip(output, a_min=0, a_max=1.0)


"""
-Compare trained model in terms of reconstruction time and data fidelity.
+Evaluate trained model in terms of reconstruction time and data fidelity.
"""
-snr_eval = metric.snr(test_ds["label"], output)
-psnr_eval = metric.psnr(test_ds["label"], output)
+snr_eval = metric.snr(test_ds["label"][:maxn], output)
+psnr_eval = metric.psnr(test_ds["label"][:maxn], output)
print(
f"{'UNet training':15s}{'epochs:':2s}{train_conf['num_epochs']:>5d}"
f"{'':21s}{'time[s]:':10s}{time_train:>7.2f}"
@@ -181,10 +188,10 @@


"""
-Plot convergence statistics. Statistics only generated if a training
-cycle was done (i.e. not reading final epoch results from checkpoint).
+Plot convergence statistics. Statistics are generated only if a training
+cycle was done (i.e. if not reading final epoch results from checkpoint).
"""
-if stats_object is not None:
+if stats_object is not None and len(stats_object.iterations) > 0:
hist = stats_object.history(transpose=True)
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
45 changes: 26 additions & 19 deletions examples/scripts/deconv_modl_train_foam1.py
@@ -45,8 +45,9 @@
from functools import partial
from time import time

+import numpy as np
+
import jax
-import jax.numpy as jnp

from mpl_toolkits.axes_grid1 import make_axes_locatable

@@ -72,7 +73,7 @@

n = 3 # convolution kernel size
σ = 20.0 / 255 # noise level
-psf = jnp.ones((n, n)) / (n * n) # blur kernel
+psf = np.ones((n, n)) / (n * n) # blur kernel

ishape = (output_size, output_size)
opBlur = CircularConvolve(h=psf, input_shape=ishape)
@@ -127,6 +128,7 @@
"warmup_epochs": 0,
"log_every_steps": 100,
"log": True,
"checkpointing": True,
}


@@ -161,10 +163,11 @@
)

stats_object_ini = None
+stats_object = None

checkpoint_files = []
for dirpath, dirnames, filenames in os.walk(workdir2):
-checkpoint_files = [fn for fn in filenames if str.split(fn, "_")[0] == "checkpoint"]
+checkpoint_files = [fn for fn in filenames]

if len(checkpoint_files) > 0:
model = sflax.MoDLNet(
@@ -245,21 +248,25 @@
"""
del train_ds["image"]
del train_ds["label"]
-start_time = time()

fmap = sflax.FlaxMap(model, modvar)
-output = fmap(test_ds["image"])
+del model, modvar
+
+maxn = test_nimg // 4
+start_time = time()
+output = fmap(test_ds["image"][:maxn])
time_eval = time() - start_time
-output = jnp.clip(output, a_min=0, a_max=1.0)
+output = np.clip(output, a_min=0, a_max=1.0)


"""
-Compare trained model in terms of reconstruction time
+Evaluate trained model in terms of reconstruction time
and data fidelity.
"""
total_epochs = epochs_init + train_conf["num_epochs"]
total_time_train = time_init + time_train
-snr_eval = metric.snr(test_ds["label"], output)
-psnr_eval = metric.psnr(test_ds["label"], output)
+snr_eval = metric.snr(test_ds["label"][:maxn], output)
+psnr_eval = metric.psnr(test_ds["label"][:maxn], output)
print(
f"{'MoDLNet training':18s}{'epochs:':2s}{total_epochs:>5d}{'':21s}"
f"{'time[s]:':10s}{total_time_train:>7.2f}"
@@ -273,8 +280,8 @@
"""
Plot comparison.
"""
-key = jax.random.PRNGKey(54321)
-indx = jax.random.randint(key, (1,), 0, test_nimg)[0]
+np.random.seed(123)
+indx = np.random.randint(0, high=maxn)

fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))
plot.imview(test_ds["label"][indx, ..., 0], title="Ground truth", cbar=None, fig=fig, ax=ax[0])
@@ -306,14 +313,14 @@


"""
-Plot convergence statistics. Statistics only generated if a training
-cycle was done (i.e. not reading final epoch results from checkpoint).
+Plot convergence statistics. Statistics are generated only if a training
+cycle was done (i.e. if not reading final epoch results from checkpoint).
"""
-if stats_object is not None:
+if stats_object is not None and len(stats_object.iterations) > 0:
hist = stats_object.history(transpose=True)
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
-jnp.vstack((hist.Train_Loss, hist.Eval_Loss)).T,
+np.vstack((hist.Train_Loss, hist.Eval_Loss)).T,
x=hist.Epoch,
ptyp="semilogy",
title="Loss function",
Expand All @@ -324,7 +331,7 @@
ax=ax[0],
)
plot.plot(
-jnp.vstack((hist.Train_SNR, hist.Eval_SNR)).T,
+np.vstack((hist.Train_SNR, hist.Eval_SNR)).T,
x=hist.Epoch,
title="Metric",
xlbl="Epoch",
Expand All @@ -336,11 +343,11 @@
fig.show()

# Stats for initialization loop
-if stats_object_ini is not None:
+if stats_object_ini is not None and len(stats_object_ini.iterations) > 0:
hist = stats_object_ini.history(transpose=True)
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
-jnp.vstack((hist.Train_Loss, hist.Eval_Loss)).T,
+np.vstack((hist.Train_Loss, hist.Eval_Loss)).T,
x=hist.Epoch,
ptyp="semilogy",
title="Loss function - Initialization",
Expand All @@ -351,7 +358,7 @@
ax=ax[0],
)
plot.plot(
-jnp.vstack((hist.Train_SNR, hist.Eval_SNR)).T,
+np.vstack((hist.Train_SNR, hist.Eval_SNR)).T,
x=hist.Epoch,
title="Metric - Initialization",
xlbl="Epoch",
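Several edits in this script swap `jax.numpy` for plain NumPy (`psf`, `np.clip`, `np.vstack`, and the `np.random` selection of a display index). The division of labor they suggest, sketched here under the assumption that only NumPy and JAX are involved: host-side constants, plotting data, and display RNG stay as NumPy arrays, with `jax.numpy` reserved for the traced computation itself.

```python
import numpy as np

import jax.numpy as jnp

# Host side: plain NumPy for constants and anything destined for plotting.
n = 3
psf = np.ones((n, n)) / (n * n)  # blur kernel as a NumPy array

# Device side: jax.numpy for the computation, e.g. circular convolution of
# an impulse image with the kernel via the FFT.
x = jnp.zeros((8, 8)).at[4, 4].set(1.0)
H = jnp.fft.fft2(jnp.asarray(psf), s=x.shape)
blurred = jnp.real(jnp.fft.ifft2(H * jnp.fft.fft2(x)))
```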
(diff for the remaining 14 files not shown)

