Skip to content

Commit

Permalink
Small changes to scripts. (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam authored Nov 29, 2023
1 parent 4f59f0b commit 8e3747b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
8 changes: 4 additions & 4 deletions scripts/recon/train_unrolled.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,15 +409,16 @@ def train_unrolled(config):
dataset.psf = dataset.psf.to(device)
log.info(f"Data shape : {dataset[0][0].shape}")

if config.files.n_files is not None:
dataset = Subset(dataset, np.arange(config.files.n_files))
dataset.psf = dataset.dataset.psf

# train-test split
train_size = int((1 - config.files.test_size) * len(dataset))
test_size = len(dataset) - train_size
train_set, test_set = torch.utils.data.random_split(
dataset, [train_size, test_size], generator=generator
)
if config.files.n_files is not None:
train_set = Subset(train_set, np.arange(config.files.n_files))
test_set = Subset(test_set, np.arange(config.files.n_files))

# -- if learning mask
downsample = config.files.downsample * 4 # measured files are 4x downsampled
Expand Down Expand Up @@ -470,7 +471,6 @@ def train_unrolled(config):

for i, _idx in enumerate(config.test_idx):

# lensless, lensed = dataset[_idx]
lensless, lensed = test_set[_idx]
recon = ADMM(psf)

Expand Down
5 changes: 4 additions & 1 deletion scripts/sim/digicam_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from slm_controller import slm
from lensless.utils.io import save_image, get_dtype, load_psf
from lensless.utils.plot import plot_image
from lensless.utils.image import gamma_correction
from lensless.hardware.sensor import VirtualSensor
from lensless.hardware.slm import get_programmable_mask, get_intensity_psf
from waveprop.devices import slm_dict
Expand Down Expand Up @@ -41,6 +40,9 @@ def digicam_psf(config):
Load pattern
"""
pattern = np.load(fp)
# - make random pattern like original
# pattern = np.random.rand(*pattern.shape) * 255
# pattern = pattern.astype(np.uint8)

# -- apply aperture
aperture = np.zeros(pattern.shape, dtype=np.uint8)
Expand All @@ -58,6 +60,7 @@ def digicam_psf(config):
idx_1 : idx_1 + ap_shape[0],
idx_2 : idx_2 + ap_shape[1],
]

print("Controllable region shape: ", pattern_sub.shape)
print("Total number of pixels: ", np.prod(pattern_sub.shape))

Expand Down

0 comments on commit 8e3747b

Please sign in to comment.