Commit

Support for dataset with measured background and reconstruction approaches that use it (#142)

* Now supports a dataset that also has background images

* Fixed PR.

  TODO: find alignment factor.

* Remove HFDataset inheritance from DualDataset.

* Fix ``__getitem__``.

* Update alignment parameters.

* Reformat.

* Add flag for direct background subtraction.

* Add flag for learned background subtraction.

* Fix missing default.

* Add background subtraction network creation.

* Add integrated subtraction, and concat background to UNetRes input.

* Make background subtraction backward compatible.

* Fix downsample factor.

* Clamp image after subtraction.

* Add option to remove background for benchmarking.

* Add background support to benchmark/eval.

* Update upload config.

* Clamp after background subtraction.

* Update changelog.

---------

Co-authored-by: Eric Bezzam <[email protected]>
StefanPetersTM and ebezzam authored Sep 12, 2024
1 parent 96b292f commit 3ea395f
Showing 24 changed files with 1,125 additions and 76 deletions.
11 changes: 9 additions & 2 deletions CHANGELOG.rst
@@ -13,7 +13,6 @@ Unreleased
Added
~~~~~

- Option to pass background image to ``utils.io.load_data``.
- Option to set image resolution with ``hardware.utils.display`` function.
- Add utility for mask adapter generation in ``lensless.hardware.fabrication``.
- Option to add simulated background in ``utils.dataset``.
@@ -28,11 +27,19 @@ Added
- HFSimulated object for simulating lensless data from ground-truth and PSF.
- Option to set cache directory for Hugging Face datasets.
- Option to initialize training with another model.
- Option to pass background image to ``utils.io.load_data``.
- Option to use background in ``lensless.eval.benchmark``.
- Different techniques for using the measured background: direct subtraction, learned subtraction, integrated subtraction, and concatenation to the input.
- Learnable background subtraction for classes that derive from ``lensless.recon.trainable_recon.TrainableReconstructionAlgorithm``.
- Integrated background subtraction object ``lensless.recon.integrated_background.IntegratedBackgroundSub``.
- Option to concatenate the background to the pre-processor input.
- Add support for datasets with measured background to ``lensless.utils.dataset.HFDataset``.


Changed
~~~~~~~

- Nothing
- ``lensless.utils.dataset.HFDataset`` no longer inherits from ``lensless.utils.dataset.DualDataset``.

Bugfix
~~~~~~
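As a rough illustration of the first and simplest technique listed above, direct subtraction removes the measured background from the lensless measurement; the clamp mirrors the "Clamp after background subtraction" commits. A minimal sketch, with a hypothetical helper name:

import torch

def direct_background_subtraction(lensless: torch.Tensor, background: torch.Tensor) -> torch.Tensor:
    # Subtract the measured background from the measurement and clamp to
    # non-negative values (negative intensities are not physical).
    return torch.clamp(lensless - background, min=0.0)
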
1 change: 1 addition & 0 deletions configs/benchmark.yaml
@@ -26,6 +26,7 @@ huggingface:
downsample_lensed: 1
split_seed: null
single_channel_psf: False
use_background: True

device: "cuda"
# numbers of iterations to benchmark
45 changes: 45 additions & 0 deletions configs/benchmark_multilens_mirflickr_ambient.yaml
@@ -0,0 +1,45 @@
# python scripts/eval/benchmark_recon.py -cn benchmark_multilens_mirflickr_ambient
defaults:
- benchmark
- _self_

dataset: HFDataset
batchsize: 8
device: "cuda:0"

huggingface:
repo: Lensless/MultiLens-Mirflickr-Ambient
cache_dir: /dev/shm
psf: psf.png
image_res: [600, 600] # used during measurement
rotate: False # if measurement is upside-down
alignment:
top_left: [118, 220] # height, width
height: 123
use_background: True

## -- reconstructions trained with same dataset/system
algorithms: [
"ADMM",
"hf:multilens:mirflickr_ambient:U5+Unet8M",
"hf:multilens:mirflickr_ambient:U5+Unet8M_direct_sub",
"hf:multilens:mirflickr_ambient:U5+Unet8M_learned_sub",
"hf:multilens:mirflickr_ambient:Unet4M+U5+Unet4M",
"hf:multilens:mirflickr_ambient:Unet4M+U5+Unet4M_direct_sub",
"hf:multilens:mirflickr_ambient:Unet4M+U5+Unet4M_learned_sub",
"hf:multilens:mirflickr_ambient:Unet4M+U5+Unet4M_concat",
"hf:multilens:mirflickr_ambient:TrainInv+Unet8M",
"hf:multilens:mirflickr_ambient:TrainInv+Unet8M_learned_sub",
"hf:multilens:mirflickr_ambient:Unet4M+TrainInv+Unet4M",
"hf:multilens:mirflickr_ambient:Unet4M+TrainInv+Unet4M_learned_sub",
"hf:multilens:mirflickr_ambient:Unet4M+TrainInv+Unet4M_concat",
"hf:multilens:mirflickr_ambient:TrainInv+Unet8M_direct_sub",
"hf:multilens:mirflickr_ambient:Unet4M+TrainInv+Unet4M_direct_sub",
]

save_idx: [
1, 2, 4, 5, 9, 64, # bottom right
2141, 2155, 2162, 2225, 2502, 2602, # top right (door, flower, cookies, wolf, plush, sky)
3262, 3304, 3438, 3451, 3644, 3667 # bottom left (pancakes, flower, grapes, pencils, bird, sign)
]
n_iter_range: [100] # for ADMM
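The ``hf:...`` entries appear to follow an ``hf:<camera>:<dataset>:<model>`` naming scheme for pretrained reconstructions hosted on Hugging Face. A hypothetical parsing sketch, not the library's API:

def parse_algorithm_id(algo: str) -> dict:
    # e.g. "hf:multilens:mirflickr_ambient:Unet4M+U5+Unet4M_learned_sub"
    prefix, camera, dataset, model = algo.split(":")
    assert prefix == "hf", "only Hugging Face identifiers are handled here"
    return {"camera": camera, "dataset": dataset, "model": model}
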
51 changes: 51 additions & 0 deletions configs/recon_multilens_ambient_mirflickr.yaml
@@ -0,0 +1,51 @@
# python scripts/recon/multilens_ambient_mirflickr.py
defaults:
- defaults_recon
- _self_

cache_dir: /dev/shm

## - Uncomment to reconstruct from dataset (screen capture)
idx: 1 # index from test set to reconstruct
fn: null # if not null, set local path or download this file from https://huggingface.co/datasets/Lensless/MultiLens-Mirflickr-Ambient/tree/main
background_fn: null

## - Uncomment to reconstruct plush parrot (direct capture)
# fn: parrot_raw.png
# background_fn: parrot_background.png
# rotate: False
# alignment:
# dim: [160, 160]
# top_left: [110, 200]

## - Uncomment to reconstruct plush monkey (direct capture)
# fn: monkey_raw.png
# background_fn: monkey_background.png
# rotate: False
# alignment:
# dim: [123, 123]
# top_left: [118, 220]

## - Uncomment to reconstruct plant (direct capture)
# fn: plant_raw.png
# background_fn: plant_background.png
# rotate: False
# alignment:
# dim: [200, 200]
# top_left: [60, 186]

## Reconstruction
background_sub: True # whether to subtract background

# -- for learning-based methods (uncomment one line)
model: Unet4M+U5+Unet4M_concat
# model: U5+Unet8M
# model: Unet4M+U5+Unet4M_learned_sub

# # -- for ADMM with fixed parameters (uncomment and comment learning-based methods)
# model: admm
# n_iter: 100

device: cuda:0
n_trials: 1 # to get average inference time
save: True
24 changes: 24 additions & 0 deletions configs/train_mirflickr_multilens_ambient.yaml
@@ -0,0 +1,24 @@
# python scripts/recon/train_learning_based.py -cn train_mirflickr_multilens_ambient
defaults:
- train_mirflickr_tape
- _self_

wandb_project: multilens_ambient

# Dataset
files:
dataset: Lensless/MultiLens-Mirflickr-Ambient
cache_dir: /dev/shm
image_res: [600, 600]

reconstruction:
direct_background_subtraction: True

alignment:
# when there is no downsampling
top_left: [118, 220] # height, width
height: 123

optimizer:
type: AdamW
cosine_decay_warmup: True
24 changes: 24 additions & 0 deletions configs/train_mirflickr_tape_ambient.yaml
@@ -0,0 +1,24 @@
# python scripts/recon/train_learning_based.py -cn train_mirflickr_tape_ambient
defaults:
- train_mirflickr_tape
- _self_

wandb_project: tapecam_ambient
device_ids:

# Dataset
files:
dataset: Lensless/TapeCam-Mirflickr-Ambient
image_res: [600, 600]

reconstruction:
direct_background_subtraction: True

alignment:
# when there is no downsampling
top_left: [85, 185] # height, width
height: 178

optimizer:
type: AdamW
cosine_decay_warmup: True
33 changes: 33 additions & 0 deletions configs/train_mirflickr_tape_ambient_integrated_sub.yaml
@@ -0,0 +1,33 @@
# python scripts/recon/train_learning_based.py -cn train_mirflickr_tape_ambient_integrated_sub
defaults:
- train_mirflickr_tape
- _self_

wandb_project: tapecam_ambient
device_ids: [0, 1, 2, 3]
torch_device: cuda:0

# Dataset
files:
dataset: Lensless/TapeCam-Mirflickr-Ambient # 16K examples
cache_dir: /dev/shm
#dataset: Lensless/TapeCam-Mirflickr-Ambient-100 # 100 examples
image_res: [600, 600]

reconstruction:
# one or the other
direct_background_subtraction: False
learned_background_subtraction: False
integrated_background_subtraction: [32, 64, 128, 210, 210]
down_subtraction: False
pre_process:
network: null # TODO assert null when integrated_background_subtraction is not False

alignment:
# when there is no downsampling
top_left: [85, 185] # height, width
height: 178

optimizer:
type: AdamW
cosine_decay_warmup: True
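Here the list [32, 64, 128, 210, 210] sets the per-scale channel widths of the integrated subtraction network (``lensless.recon.integrated_background.IntegratedBackgroundSub`` per the changelog), and ``down_subtraction`` selects whether the background is fused on the downsampling or the upsampling path. A heavily simplified two-scale sketch that fuses during downsampling only; all internals are assumptions:

import torch
import torch.nn as nn

class IntegratedSubSketch(nn.Module):
    # Toy version: encode measurement and background with parallel conv
    # stacks and subtract the background features at each scale.
    def __init__(self, nc=(32, 64)):
        super().__init__()

        def block(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(), nn.AvgPool2d(2)
            )

        self.img_enc = nn.ModuleList([block(3, nc[0]), block(nc[0], nc[1])])
        self.bg_enc = nn.ModuleList([block(3, nc[0]), block(nc[0], nc[1])])

    def forward(self, x, bg):
        for img_block, bg_block in zip(self.img_enc, self.bg_enc):
            x, bg = img_block(x), bg_block(bg)
            x = x - bg  # fuse by subtracting background features at this scale
        return x  # multi-scale features, to be decoded by the rest of the network
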
36 changes: 36 additions & 0 deletions configs/train_mirflickr_tape_ambient_learned_sub.yaml
@@ -0,0 +1,36 @@
# python scripts/recon/train_learning_based.py -cn train_mirflickr_tape_ambient_learned_sub
defaults:
- train_mirflickr_tape
- _self_

wandb_project: tapecam_ambient
device_ids: [0, 1, 2, 3]
torch_device: cuda:0

# Dataset
files:
#n_files: 10
dataset: Lensless/TapeCam-Mirflickr-Ambient # 16K examples
#dataset: Lensless/TapeCam-Mirflickr-Ambient-100 # 100 examples
cache_dir: /dev/shm
image_res: [600, 600]

reconstruction:
# one or the other
direct_background_subtraction: False
learned_background_subtraction: [4, 8, 16, 32] # 127740 parameters, False to turn off
integrated_background_subtraction: False

pre_process: ## Targeting 3923428 parameters
network: UnetRes # UnetRes or DruNet or null
depth: 4 # depth of each up/downsampling layer. Ignored if network is DruNet
nc: [32, 64, 112, 128]

alignment:
# when there is no downsampling
top_left: [85, 185] # height, width
height: 178

optimizer:
type: AdamW
cosine_decay_warmup: True
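As configured here, learned subtraction passes the measured background through a small network ([4, 8, 16, 32] channel widths, reported as 127740 parameters with a UnetRes) and subtracts its output from the measurement before the pre-processor. A minimal sketch with a plain conv stack standing in for the UnetRes; names and internals are assumptions:

import torch
import torch.nn as nn

class LearnedBackgroundSubSketch(nn.Module):
    # A small conv net estimates the background contribution to remove.
    def __init__(self, nc=(4, 8, 16, 32)):
        super().__init__()
        layers, in_ch = [], 3
        for out_ch in nc:
            layers += [nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU()]
            in_ch = out_ch
        layers.append(nn.Conv2d(in_ch, 3, 3, padding=1))
        self.net = nn.Sequential(*layers)

    def forward(self, lensless, background):
        # Subtract the learned background estimate, then clamp.
        return torch.clamp(lensless - self.net(background), min=0.0)
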
2 changes: 1 addition & 1 deletion configs/train_tapecam_simulated_background.yaml
@@ -1,4 +1,4 @@
# python scripts/recon/train_learning_based.py -cn train_mirflickr_tape
# python scripts/recon/train_learning_based.py -cn train_tapecam_simulated_background
defaults:
- train_mirflickr_tape
- _self_
Expand Down
14 changes: 11 additions & 3 deletions configs/train_unrolledADMM.yaml
@@ -90,6 +90,15 @@ reconstruction:
init_pre: True # if `init_processors`, initialize pre-processor if available
init_post: True # if `init_processors`, initialize post-processor if available

# background subtraction (if dataset has corresponding background images)
direct_background_subtraction: False # True or False
learned_background_subtraction: False # False, or set number of channels for UnetRes, e.g. [8,16,32,64]
integrated_background_subtraction: False # False, or set number of channels for UnetRes, e.g. [8,16,32,64]
down_subtraction: False # for integrated_background_subtraction, whether to concatenate background subtraction during downsample or upsample
integrated_background_unetres: False # whether to integrate within UNetRes
unetres_input_background: False # whether to input background to UNetRes


# Hyperparameters for each method
unrolled_fista: # for unrolled_fista
# Number of iterations
@@ -181,18 +190,17 @@ training:
crop_preloss: False # crop region for computing loss, files.crop should be set

optimizer:
type: Adam # Adam, SGD... (Pytorch class)
type: AdamW # Adam, SGD... (Pytorch class)
lr: 1e-4
lr_step_epoch: True # True -> update LR at end of each epoch, False at the end of each mini-batch
final_lr: False # if set, exponentially decay *to* this value
exp_decay: False # if set, exponentially decay *with* this value
slow_start: False # float, how much to reduce LR for the first epoch
cosine_decay_warmup: False # if set, cosine decay with warmup of 5%
cosine_decay_warmup: True # if set, cosine decay with warmup of 5%
# Decay LR in step fashion: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html
step: False # int, period of learning rate decay. False to not apply
gamma: 0.1 # float, factor for learning rate decay


loss: 'l2'
# set lpips to false to deactivate. Otherwise, give the weight for the loss (the main loss l2/l1 always has a weight of 1)
lpips: 1.0
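The ``unetres_input_background`` flag above corresponds to the "concatenated to input" technique from the changelog: rather than subtracting, the background is stacked with the measurement along the channel dimension so the pre-processor sees both. A one-line sketch with illustrative NCHW shapes (the library's own tensor layout may differ):

import torch

lensless = torch.rand(1, 3, 240, 320)    # measurement
background = torch.rand(1, 3, 240, 320)  # measured background
# Pre-processor input grows from 3 to 6 channels.
unet_input = torch.cat([lensless, background], dim=1)
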
6 changes: 5 additions & 1 deletion configs/upload_multilens_mirflickr_ambient.yaml
@@ -6,13 +6,17 @@ defaults:
repo_id: "Lensless/MultiLens-Mirflickr-Ambient"
n_files:
test_size: 0.15

# # -- to match TapeCam dataset content distribution, and same light distribution in train/test
# split: 100 # "first": first `n_files*test_size` for test; `int`: test_size*split for test (interleaved), as if multimask with this many masks

lensless:
dir: /dev/shm/all_measured_20240813-183259
ambient: True
ext: ".png"

lensed:
dir: data/mirflickr/mirflickr
dir: /root/LenslessPiCam/data/mirflickr/mirflickr
ext: ".jpg"

files:
9 changes: 8 additions & 1 deletion lensless/eval/benchmark.py
@@ -42,6 +42,7 @@ def benchmark(
use_wandb=False,
label=None,
epoch=None,
use_background=True,
**kwargs,
):
"""
@@ -69,6 +70,8 @@
If True, return the average value of the metrics, by default True.
snr : float, optional
Signal to noise ratio for adding shot noise. If None, no noise is added, by default None.
use_background : bool, optional
If the dataset has measured backgrounds, use them for reconstruction, by default True.
Returns
-------
@@ -121,8 +124,11 @@

flip_lr = None
flip_ud = None
background = None
lensless = batch[0].to(device)
lensed = batch[1].to(device)
if dataset.measured_bg and use_background:
background = batch[-1].to(device)
if dataset.multimask or dataset.random_flip:
psfs = batch[2]
psfs = psfs.to(device)
@@ -146,11 +152,12 @@
plot=False,
save=False,
output_intermediate=unrolled_output_factor or pre_process_aux,
background=background,
**kwargs,
)

else:
prediction = model.forward(lensless, psfs, **kwargs)
prediction = model.forward(lensless, psfs, background=background, **kwargs)

if unrolled_output_factor or pre_process_aux:
pre_process_out = prediction[2]
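The batch handling added above can be restated as a standalone helper (hypothetical name; logic taken directly from the diff): when the dataset has measured backgrounds and ``use_background`` is True, the background is the last element of the batch and is forwarded to the model's forward call.

def unpack_batch(batch, dataset, device, use_background=True):
    # Mirrors the added logic: lensless and lensed come first; the
    # measured background, when present, is the last batch element.
    lensless = batch[0].to(device)
    lensed = batch[1].to(device)
    background = None
    if getattr(dataset, "measured_bg", False) and use_background:
        background = batch[-1].to(device)
    return lensless, lensed, background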