diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index feef375a..eaace9a8 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -409,15 +409,16 @@ def train_unrolled(config): dataset.psf = dataset.psf.to(device) log.info(f"Data shape : {dataset[0][0].shape}") + if config.files.n_files is not None: + dataset = Subset(dataset, np.arange(config.files.n_files)) + dataset.psf = dataset.dataset.psf + # train-test split train_size = int((1 - config.files.test_size) * len(dataset)) test_size = len(dataset) - train_size train_set, test_set = torch.utils.data.random_split( dataset, [train_size, test_size], generator=generator ) - if config.files.n_files is not None: - train_set = Subset(train_set, np.arange(config.files.n_files)) - test_set = Subset(test_set, np.arange(config.files.n_files)) # -- if learning mask downsample = config.files.downsample * 4 # measured files are 4x downsampled @@ -470,7 +471,6 @@ def train_unrolled(config): for i, _idx in enumerate(config.test_idx): - # lensless, lensed = dataset[_idx] lensless, lensed = test_set[_idx] recon = ADMM(psf) diff --git a/scripts/sim/digicam_psf.py b/scripts/sim/digicam_psf.py index 9b665c27..d68d35be 100644 --- a/scripts/sim/digicam_psf.py +++ b/scripts/sim/digicam_psf.py @@ -8,7 +8,6 @@ from slm_controller import slm from lensless.utils.io import save_image, get_dtype, load_psf from lensless.utils.plot import plot_image -from lensless.utils.image import gamma_correction from lensless.hardware.sensor import VirtualSensor from lensless.hardware.slm import get_programmable_mask, get_intensity_psf from waveprop.devices import slm_dict @@ -41,6 +40,9 @@ def digicam_psf(config): Load pattern """ pattern = np.load(fp) + # - make random pattern like original + # pattern = np.random.rand(*pattern.shape) * 255 + # pattern = pattern.astype(np.uint8) # -- apply aperture aperture = np.zeros(pattern.shape, dtype=np.uint8) @@ -58,6 +60,7 @@ def digicam_psf(config): idx_1 : idx_1 + ap_shape[0], idx_2 : idx_2 + ap_shape[1], ] + print("Controllable region shape: ", pattern_sub.shape) print("Total number of pixels: ", np.prod(pattern_sub.shape))