Skip to content

Commit

Permalink
Streamlined training with new Trainer class (#77)
Browse files Browse the repository at this point in the history
* move utility function outside of script

* New trainer class for training reconstruction

* Update docstring

* Update changelog

* Update to trainer save

* Fix partial mask support bug

* Fix docstrings.

* Fix APGD rendering.

---------

Co-authored-by: Eric Bezzam <[email protected]>
  • Loading branch information
YohannPerron and ebezzam authored Aug 29, 2023
1 parent bba42c0 commit 58f747a
Show file tree
Hide file tree
Showing 10 changed files with 453 additions and 233 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Added
- New dataset for pair of original image and their measurement from a screen. See ``utils.dataset.MeasuredDataset`` and ``utils.dataset.MeasuredDatasetSimulatedOriginal``.
- Support for unrolled loading and inference in the script ``admm.py``.
- Tikhonov reconstruction for coded aperture measurements (MLS / MURA).
- New ``Trainer`` class to train ``TrainableReconstructionAlgorithm`` with PyTorch.


Changed
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ numpy>=1.22 # so that default dtype are correctly rendered
torch>=1.10
torchvision>=0.15.2
torchmetrics>=0.11.4
pyFFS>=2.2.3 # for waveprop
waveprop>=0.0.5
4 changes: 4 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
"torchmetrics.image",
"scipy.ndimage",
"pycsou.abc",
"pycsou.operator",
"pycsou.operator.func",
"pycsou.operator.linop",
"pycsou.opt.solver",
"pycsou.opt.stop",
"pycsou.runtime",
Expand All @@ -33,6 +35,8 @@
"paramiko",
"paramiko.ssh_exception",
"perlin_numpy",
"hydra",
"hydra.utils",
"scipy.special",
"matplotlib.cm",
"pyffs",
Expand Down
22 changes: 20 additions & 2 deletions docs/source/reconstruction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
Accelerated Proximal Gradient Descent (APGD)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: lensless.APGD
.. autoclass:: lensless.recon.apgd.APGD
:special-members: __init__


Expand Down Expand Up @@ -88,4 +88,22 @@
.. autoclass:: lensless.UnrolledADMM
:members: batch_call
:special-members: __init__
:show-inheritance:
:show-inheritance:


Reconstruction Utilities
------------------------

.. autoclass:: lensless.recon.utils.Trainer
:members:
:special-members: __init__

.. autofunction:: lensless.recon.utils.load_drunet

.. autofunction:: lensless.recon.utils.apply_denoiser

.. autofunction:: lensless.recon.utils.get_drunet_function

.. autofunction:: lensless.recon.utils.measure_gradient

.. autofunction:: lensless.recon.utils.create_process_network
2 changes: 0 additions & 2 deletions lensless/eval/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from lensless.utils.dataset import DiffuserCamTestDataset
from tqdm import tqdm

from lensless.utils.io import load_image

try:
import torch
from torch.utils.data import DataLoader
Expand Down
Loading

0 comments on commit 58f747a

Please sign in to comment.