diff --git a/.gitignore b/.gitignore index 3430ec87..ca9b0dfa 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ data/* models/* *.png *.jpg +*.npy configs/telegram_demo_secret.yaml diff --git a/configs/adafruit.yaml b/configs/adafruit.yaml new file mode 100644 index 00000000..5399204e --- /dev/null +++ b/configs/adafruit.yaml @@ -0,0 +1,9 @@ +defaults: + - demo + - _self_ + +plot: True + +capture: + exp: 5.0 + awb_gains: [1, 1] diff --git a/configs/apgd_l1.yaml b/configs/apgd_l1.yaml index 5d0621cc..006b72aa 100644 --- a/configs/apgd_l1.yaml +++ b/configs/apgd_l1.yaml @@ -3,6 +3,10 @@ defaults: - defaults_recon - _self_ +preprocess: + # Downsampling factor along X and Y + downsample: 8 + apgd: # Proximal prior / regularization: nonneg, l1, null prox_penalty: l1 diff --git a/configs/apgd_l2.yaml b/configs/apgd_l2.yaml index 65a16405..0b50ba73 100644 --- a/configs/apgd_l2.yaml +++ b/configs/apgd_l2.yaml @@ -3,6 +3,10 @@ defaults: - defaults_recon - _self_ +preprocess: + # Downsampling factor along X and Y + downsample: 8 + apgd: diff_penalty: l2 diff_lambda: 0.0001 diff --git a/configs/defaults_recon.yaml b/configs/defaults_recon.yaml index 5cd05d6c..324aa679 100644 --- a/configs/defaults_recon.yaml +++ b/configs/defaults_recon.yaml @@ -27,6 +27,7 @@ preprocess: single_psf: False # Whether to perform construction in grayscale. gray: False + bg_pix: [5, 25] # null to skip display: diff --git a/configs/demo.yaml b/configs/demo.yaml index c769d1a2..ddc0c528 100644 --- a/configs/demo.yaml +++ b/configs/demo.yaml @@ -26,6 +26,8 @@ display: psf: null # all black screen black: False + # all white screen + white: False capture: gamma: null # for visualization diff --git a/configs/digicam.yaml b/configs/digicam.yaml new file mode 100644 index 00000000..d84b3a89 --- /dev/null +++ b/configs/digicam.yaml @@ -0,0 +1,23 @@ +rpi: + username: null + hostname: null + +device: adafruit +virtual: False +save: True + +# pattern: data/psf/adafruit_random_pattern_20230719.npy +pattern: random +# pattern: rect +# pattern: circ +min_val: 0 # if pattern: random, min for range(0,1) +rect_shape: [20, 10] # if pattern: rect +radius: 20 # if pattern: circ +center: [0, 0] + + +aperture: + center: [59,76] + shape: [19,26] + +z: 4 # mask to sensor distance diff --git a/configs/recon_dataset.yaml b/configs/recon_dataset.yaml new file mode 100644 index 00000000..f474aed5 --- /dev/null +++ b/configs/recon_dataset.yaml @@ -0,0 +1,47 @@ +# python scripts/recon/dataset.py +defaults: + - defaults_recon + - _self_ + +torch: True +torch_device: 'cuda:0' + +input: + # https://drive.switch.ch/index.php/s/NdgHlcDeHVDH5ww?path=%2Fpsf + psf: data/psf/adafruit_random_2mm_20231907.png + # https://drive.switch.ch/index.php/s/m89D1tFEfktQueS + raw_data: data/celeba_adafruit_random_2mm_20230720_1K + +n_files: 25 # null for all files +output_folder: data/celeba_adafruit_recon + +# extraction region of interest +roi: null # top, left, bottom, right +# -- values for `data/celeba_adafruit_random_2mm_20230720_1K` +# roi: [10, 300, 560, 705] # down 4 +# roi: [6, 200, 373, 470] # down 6 +# roi: [5, 150, 280, 352] # down 8 + +preprocess: + flip: True + downsample: 6 + + # to have different data shape than PSF + data_dim: null + # data_dim: [48, 64] # down 64 + # data_dim: [506, 676] # down 6 + +display: + disp: -1 + plot: False + +algo: admm # "admm", "apgd", "null" to just copy over (resized) raw data + +apgd: + n_jobs: 1 # run in parallel as algo is slow + max_iter: 500 + +admm: + n_iter: 10 + +save: False \ No newline at end of file diff 
--git a/configs/sim_digicam_psf.yaml b/configs/sim_digicam_psf.yaml new file mode 100644 index 00000000..216455cd --- /dev/null +++ b/configs/sim_digicam_psf.yaml @@ -0,0 +1,38 @@ +# python scripts/sim/digicam_psf.py +hydra: + job: + chdir: True # change to output folder + +use_torch: False +dtype: float32 +torch_device: cuda +requires_grad: True + +digicam: + + slm: adafruit + sensor: rpi_hq + + # https://drive.switch.ch/index.php/s/NdgHlcDeHVDH5ww?path=%2Fpsf + pattern: data/psf/adafruit_random_pattern_20230719.npy + ap_center: [59, 76] + ap_shape: [19, 26] + rotate: -0.8 # rotation in degrees + + # optionally provide measured PSF for side-by-side comparison + # https://drive.switch.ch/index.php/s/NdgHlcDeHVDH5ww?path=%2Fpsf + psf: data/psf/adafruit_random_2mm_20231907.png + gamma: 2 # for plotting measured + +sim: + + # whether SLM is fliped + flipud: True + + # in practice found waveprop=True or False doesn't make difference + waveprop: False + + # below are ignored if waveprop=False + scene2mask: 0.03 # [m] + mask2sensor: 0.002 # [m] + \ No newline at end of file diff --git a/configs/train_celeba_classifier.yaml b/configs/train_celeba_classifier.yaml new file mode 100644 index 00000000..11a391c8 --- /dev/null +++ b/configs/train_celeba_classifier.yaml @@ -0,0 +1,38 @@ +hydra: + job: + chdir: True # change to output folder + +seed: 0 + +data: + # -- path to original CelebA (parent directory) + original: /scratch/bezzam + + output_dir: "./vit-celeba" # basename for model output + + # -- raw + # https://drive.switch.ch/index.php/s/m89D1tFEfktQueS + measured: data/celeba_adafruit_random_2mm_20230720_10K + raw: True + + # # -- reconstructed + # # run `python scripts/recon/dataset.py` to get a reconstructed dataset + # measured: null + # raw: False + + n_files: null # null to use all in measured_folder + test_size: 0.15 + attr: Male # "Male", "Smiling", etc + +augmentation: + + random_resize_crop: False + horizontal_flip: True # cannot be used with raw measurement! + +train: + + prev: null # path to previously trained model + n_epochs: 4 + dropout: 0.1 + batch_size: 16 + learning_rate: 2e-4 diff --git a/digicam_requirements.txt b/digicam_requirements.txt new file mode 100644 index 00000000..fbbcaa30 --- /dev/null +++ b/digicam_requirements.txt @@ -0,0 +1 @@ +slm_controller @ git+https://github.com/ebezzam/slm-controller.git \ No newline at end of file diff --git a/docs/source/data.rst b/docs/source/data.rst index 768b46fb..50b323c6 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -39,6 +39,20 @@ use the correct PSF file for the data you're using! input.psf=data/psf/tape_rgb.png +Measured CelebA Dataset +----------------------- + +You can download 1K measurements of the CelebA dataset done with +our lensless camera and a random pattern on the Adafruit LCD +`here (1.2 GB) `__, +and a dataset with 10K measurements +`here (13.1 GB) `__. +They both correspond to the PSF which can be found `here `__ +(``adafruit_random_2mm_20231907.png`` which is the PSF of +``adafruit_random_pattern_20230719.npy`` measured with a mask to sensor +distance of 2 mm). 
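+
+For example, the 1K measurements can be batch-reconstructed with
+``scripts/recon/dataset.py`` (a sketch, assuming the downloaded files are
+placed under ``data/`` as in the defaults of ``configs/recon_dataset.yaml``):
+
+.. code:: bash
+
+    python scripts/recon/dataset.py \
+        input.psf=data/psf/adafruit_random_2mm_20231907.png \
+        input.raw_data=data/celeba_adafruit_random_2mm_20230720_1K \
+        algo=admm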
+ + DiffuserCam Lensless Mirflickr Dataset (DLMD) --------------------------------------------- diff --git a/lensless/hardware/aperture.py b/lensless/hardware/aperture.py new file mode 100644 index 00000000..37e8e37b --- /dev/null +++ b/lensless/hardware/aperture.py @@ -0,0 +1,379 @@ +# ############################################################################# +# aperture.py +# ================= +# Authors : +# Eric BEZZAM [ebezzam@gmail.com] +# ############################################################################# + + +from enum import Enum + +import numpy as np +from lensless.utils.image import rgb2gray + + +class ApertureOptions(Enum): + RECT = "rect" + SQUARE = "square" + LINE = "line" + CIRC = "circ" + + @staticmethod + def values(): + return [shape.value for shape in ApertureOptions] + + +class Aperture: + def __init__(self, shape, pixel_pitch): + """ + Class for defining VirtualSLM. + + :param shape: (height, width) in number of cell. + :type shape: tuple(int) + :param pixel_pitch: Pixel pitch (height, width) in meters. + :type pixel_pitch: tuple(float) + """ + assert np.all(shape) > 0 + assert np.all(pixel_pitch) > 0 + self._shape = shape + self._pixel_pitch = pixel_pitch + self._values = np.zeros((3,) + shape, dtype=np.uint8) + + @property + def size(self): + return np.prod(self._shape) + + @property + def shape(self): + return self._shape + + @property + def pixel_pitch(self): + return self._pixel_pitch + + @property + def center(self): + return np.array([self.height / 2, self.width / 2]) + + @property + def dim(self): + return np.array(self._shape) * np.array(self._pixel_pitch) + + @property + def height(self): + return self.dim[0] + + @property + def width(self): + return self.dim[1] + + @property + def values(self): + return self._values + + @property + def grayscale_values(self): + return rgb2gray(self._values) + + def at(self, physical_coord, value=None): + """ + Get/set values of VirtualSLM at physical coordinate in meters. + + :param physical_coord: Physical coordinates to get/set VirtualSLM values. + :type physical_coord: int, float, slice tuples + :param value: [Optional] values to set, otherwise return values at + specified coordinates. Defaults to None + :type value: int, float, :py:class:`~numpy.ndarray`, optional + :return: If getter is used, values at those coordinates + :rtype: ndarray + """ + idx = prepare_index_vals(physical_coord, self._pixel_pitch) + if value is None: + # getter + return self._values[idx] + else: + # setter + self._values[idx] = value + + def __getitem__(self, key): + return self._values[key] + + def __setitem__(self, key, value): + self._values[key] = value + + def plot(self, show_tick_labels=False): + """ + Plot Aperture. + + :param show_tick_labels: Whether to show cell number along x- and y-axis, defaults to False + :type show_tick_labels: bool, optional + :return: The axes of the plot. 
+ :rtype: Axes + """ + # prepare mask data for `imshow`, expects the input data array size to be (width, height, 3) + Z = self.values.transpose(1, 2, 0) + + # plot + import matplotlib.pyplot as plt + + _, ax = plt.subplots() + extent = [ + -0.5 * self._pixel_pitch[1], + (self._shape[1] - 0.5) * self._pixel_pitch[1], + (self._shape[0] - 0.5) * self._pixel_pitch[0], + -0.5 * self._pixel_pitch[0], + ] + ax.imshow(Z, extent=extent) + ax.grid(which="major", axis="both", linestyle="-", color="0.5", linewidth=0.25) + + x_ticks = np.arange(-0.5, self._shape[1], 1) * self._pixel_pitch[1] + ax.set_xticks(x_ticks) + if show_tick_labels: + x_tick_labels = (np.arange(-0.5, self._shape[1], 1) + 0.5).astype(int) + else: + x_tick_labels = [None] * len(x_ticks) + ax.set_xticklabels(x_tick_labels) + + y_ticks = np.arange(-0.5, self._shape[0], 1) * self._pixel_pitch[0] + ax.set_yticks(y_ticks) + if show_tick_labels: + y_tick_labels = (np.arange(-0.5, self._shape[0], 1) + 0.5).astype(int) + else: + y_tick_labels = [None] * len(y_ticks) + ax.set_yticklabels(y_tick_labels) + return ax + + +def rect_aperture(slm_shape, pixel_pitch, apert_dim, center=None): + """ + Create and return VirtualSLM object with rectangular aperture of desired dimensions. + + :param slm_shape: Dimensions (height, width) of VirtualSLM in cells. + :type slm_shape: tuple(int) + :param pixel_pitch: Dimensions (height, width) of each cell in meters. + :type pixel_pitch: tuple(float) + :param apert_dim: Dimensions (height, width) of aperture in meters. + :type apert_dim: tuple(float) + :param center: [Optional] center of aperture along (SLM) coordinates, indexing starts in top-left corner. + Default behavior is to place center of aperture at center of SLM. + Defaults to None + :type center: tuple(float), optional + :raises ValueError: If aperture does extend over the boarder of the SLM. + :return: VirtualSLM object with cells programmed to desired rectangular aperture. + :rtype: :py:class:`~mask_designer.virtual_slm.VirtualSLM` + """ + # check input values + assert np.all(apert_dim) > 0 + + # initialize SLM + slm = Aperture(shape=slm_shape, pixel_pitch=pixel_pitch) + + # check / compute center + if center is None: + center = slm.center + else: + assert ( + 0 <= center[0] < slm.height + ), f"Center {center} must lie within VirtualSLM dimensions {slm.dim}." + assert ( + 0 <= center[1] < slm.width + ), f"Center {center} must lie within VirtualSLM dimensions {slm.dim}." + + # compute mask + apert_dim = np.array(apert_dim) + top_left = center - apert_dim / 2 + bottom_right = top_left + apert_dim + if ( + top_left[0] < 0 + or top_left[1] < 0 + or bottom_right[0] >= slm.dim[0] + or bottom_right[1] >= slm.dim[1] + ): + raise ValueError( + f"Aperture ({top_left[0]}:{bottom_right[0]}, " + f"{top_left[1]}:{bottom_right[1]}) extends past valid " + f"VirtualSLM dimensions {slm.dim}" + ) + slm.at( + physical_coord=np.s_[top_left[0] : bottom_right[0], top_left[1] : bottom_right[1]], + value=255, + ) + + return slm + + +def line_aperture(slm_shape, pixel_pitch, length, vertical=True, center=None): + """ + Create and return VirtualSLM object with a line aperture of desired length. + + :param slm_shape: Dimensions (height, width) of VirtualSLM in cells. + :type slm_shape: tuple(int) + :param pixel_pitch: Dimensions (height, width) of each cell in meters. + :type pixel_pitch: tuple(float) + :param length: Length of aperture in meters. + :type length: float + :param vertical: Orient line vertically, defaults to True. 
+ :type vertical: bool, optional + :param center: [Optional] center of aperture along (SLM) coordinates, indexing starts in top-left corner. + Default behavior is to place center of aperture at center of SLM. + Defaults to None + :type center: tuple(float), optional + :return: VirtualSLM object with cells programmed to desired line aperture. + :rtype: :py:class:`~mask_designer.virtual_slm.VirtualSLM` + """ + # call `create_rect_aperture` + apert_dim = (length, pixel_pitch[1]) if vertical else (pixel_pitch[0], length) + return rect_aperture(slm_shape, pixel_pitch, apert_dim, center) + + +def square_aperture(slm_shape, pixel_pitch, side, center=None): + """ + Create and return VirtualSLM object with a square aperture of desired shape. + + :param slm_shape: Dimensions (height, width) of VirtualSLM in cells. + :type slm_shape: tuple(int) + :param pixel_pitch: Dimensions (height, width) of each cell in meters. + :type pixel_pitch: tuple(float) + :param side: Side length of square aperture in meters. + :type side: float + :param center: [Optional] center of aperture along (SLM) coordinates, indexing starts in top-left corner. + Default behavior is to place center of aperture at center of SLM. + Defaults to None + :type center: tuple(float), optional + :return: VirtualSLM object with cells programmed to desired square aperture. + :rtype: :py:class:`~mask_designer.virtual_slm.VirtualSLM` + """ + return rect_aperture(slm_shape, pixel_pitch, (side, side), center) + + +def circ_aperture(slm_shape, pixel_pitch, radius, center=None): + """ + Create and return VirtualSLM object with a circle aperture of desired shape. + + :param slm_shape: Dimensions (height, width) of VirtualSLM in cells. + :type slm_shape: tuple(int) + :param pixel_pitch: Dimensions (height, width) of each cell in meters. + :type pixel_pitch: tuple(float) + :param radius: Radius of aperture in meters. + :type radius: float + :param center: [Optional] center of aperture along (SLM) coordinates, indexing starts in top-left corner. + Default behavior is to place center of aperture at center of SLM. + Defaults to None + :type center: tuple(float), optional + :return: VirtualSLM object with cells programmed to desired circle aperture. + :rtype: :py:class:`~mask_designer.virtual_slm.VirtualSLM` + """ + # check input values + assert radius > 0 + + # initialize SLM + slm = Aperture(shape=slm_shape, pixel_pitch=pixel_pitch) + + # check / compute center + if center is None: + center = slm.center + else: + assert ( + 0 <= center[0] < slm.height + ), f"Center {center} must lie within VirtualSLM dimensions {slm.dim}." + assert ( + 0 <= center[1] < slm.width + ), f"Center {center} must lie within VirtualSLM dimensions {slm.dim}." + + # compute mask + i, j = np.meshgrid( + np.arange(slm.dim[0], step=slm.pixel_pitch[0]), + np.arange(slm.dim[1], step=slm.pixel_pitch[1]), + sparse=True, + indexing="ij", + ) + x2 = (i - center[0]) ** 2 + y2 = (j - center[1]) ** 2 + slm[:] = 255 * (x2 + y2 < radius**2) + return slm + + +def _cell_slice(_slice, cell_m): + """ + Convert slice indexing in meters to slice indexing in cells. + + author: Eric Bezzam, + email: ebezzam@gmail.com, + GitHub: https://github.com/ebezzam + + :param _slice: Original slice in meters. + :type _slice: slice + :param cell_m: Dimension of cell in meters. 
+ :type cell_m: float + :return: The new slice + :rtype: slice + """ + start = None if _slice.start is None else _m_to_cell_idx(_slice.start, cell_m) + stop = _m_to_cell_idx(_slice.stop, cell_m) if _slice.stop is not None else None + step = _m_to_cell_idx(_slice.step, cell_m) if _slice.step is not None else None + return slice(start, stop, step) + + +def _m_to_cell_idx(val, cell_m): + """ + Convert location to cell index. + + author: Eric Bezzam, + email: ebezzam@gmail.com, + GitHub: https://github.com/ebezzam + + :param val: Location in meters. + :type val: float + :param cell_m: Dimension of cell in meters. + :type cell_m: float + :return: The cell index. + :rtype: int + """ + return int(val / cell_m) + + +def prepare_index_vals(key, pixel_pitch): + """ + Convert indexing object in meters to indexing object in cell indices. + + author: Eric Bezzam, + email: ebezzam@gmail.com, + GitHub: https://github.com/ebezzam + + :param key: Indexing operation in meters. + :type key: int, float, slice, or list + :param pixel_pitch: Pixel pitch (height, width) in meters. + :type pixel_pitch: tuple(float) + :raises ValueError: If the key is of the wrong type. + :raises NotImplementedError: If key is of size 3, individual channels can't + be indexed. + :raises ValueError: If the key has the wrong dimensions. + :return: The new indexing object. + :rtype: tuple[slice, int] | tuple[slice, slice] | tuple[slice, ...] + """ + if isinstance(key, (float, int)): + idx = slice(None), _m_to_cell_idx(key, pixel_pitch[0]) + + elif isinstance(key, slice): + idx = slice(None), _cell_slice(key, pixel_pitch[0]) + + elif len(key) == 2: + idx = [slice(None)] + for k, _slice in enumerate(key): + + if isinstance(_slice, slice): + idx.append(_cell_slice(_slice, pixel_pitch[k])) + + elif isinstance(_slice, (float, int)): + idx.append(_m_to_cell_idx(_slice, pixel_pitch[k])) + + else: + raise ValueError("Invalid key.") + idx = tuple(idx) + + elif len(key) == 3: + raise NotImplementedError("Cannot index individual channels.") + + else: + raise ValueError("Invalid key.") + return idx diff --git a/lensless/hardware/sensor.py b/lensless/hardware/sensor.py index 36a5adda..08a00a05 100644 --- a/lensless/hardware/sensor.py +++ b/lensless/hardware/sensor.py @@ -170,6 +170,8 @@ def __init__( else: self.size = self.pixel_size * self.resolution + self.pitch = self.size / self.resolution + self.image_shape = self.resolution if self.color: self.image_shape = np.append(self.image_shape, 3) @@ -298,6 +300,7 @@ def downsample(self, factor): assert factor > 1, "Downsample factor must be greater than 1." 
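+        # keep the (newly tracked) pitch in sync with the pixel size when downsampling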
self.pixel_size = self.pixel_size * factor + self.pitch = self.pitch * factor self.resolution = (self.resolution / factor).astype(int) self.size = self.pixel_size * self.resolution self.image_shape = self.resolution diff --git a/lensless/hardware/slm.py b/lensless/hardware/slm.py new file mode 100644 index 00000000..572ae4a7 --- /dev/null +++ b/lensless/hardware/slm.py @@ -0,0 +1,298 @@ +# ############################################################################# +# slm.py +# ================= +# Authors : +# Eric BEZZAM [ebezzam@gmail.com] +# ############################################################################# + + +import os +import numpy as np +from lensless.hardware.utils import check_username_hostname +from lensless.utils.io import get_dtype, get_ctypes +from slm_controller.hardware import SLMParam, slm_devices +from waveprop.spherical import spherical_prop +from waveprop.color import ColorSystem +from waveprop.rs import angular_spectrum +from waveprop.slm import get_centers, get_color_filter +from waveprop.devices import SLMParam as SLMParam_wp +from scipy.ndimage import rotate as rotate_func + + +try: + import torch + from torchvision import transforms + + torch_available = True +except ImportError: + torch_available = False + + +SUPPORTED_DEVICE = { + "adafruit": "~/slm-controller/examples/adafruit_slm.py", + "nokia": "~/slm-controller/examples/nokia_slm.py", + "holoeye": "~/slm-controller/examples/holoeye_slm.py", +} + + +def set_programmable_mask(pattern, device, rpi_username, rpi_hostname): + """ + Set LCD pattern on Raspberry Pi. + + This function assumes that `slm-controller `_ + is installed on the Raspberry Pi. + + Parameters + ---------- + pattern : :py:class:`~numpy.ndarray` + Pattern to set on programmable mask. + device : str + Name of device to set pattern on. Supported devices: "adafruit", "nokia", "holoeye". + rpi_username : str + Username of Raspberry Pi. + rpi_hostname : str + Hostname of Raspberry Pi. + + """ + + client = check_username_hostname(rpi_username, rpi_hostname) + + # get path to python executable on Raspberry Pi + rpi_python = "~/slm-controller/slm_controller_env/bin/python" + assert ( + device in SUPPORTED_DEVICE.keys() + ), f"Device {device} not supported. Supported devices: {SUPPORTED_DEVICE.keys()}" + script = SUPPORTED_DEVICE[device] + + # check that pattern is correct shape + expected_shape = slm_devices[device][SLMParam.SLM_SHAPE] + if not slm_devices[device][SLMParam.MONOCHROME]: + expected_shape = (3, *expected_shape) + assert ( + pattern.shape == expected_shape + ), f"Pattern shape {pattern.shape} does not match expected shape {expected_shape}" + + # save pattern + pattern_fn = "tmp_pattern.npy" + local_path = os.path.join(os.getcwd(), pattern_fn) + np.save(local_path, pattern) + + # copy pattern to Raspberry Pi + remote_path = f"~/{pattern_fn}" + print(f"PUTTING {local_path} to {remote_path}") + + os.system('scp %s "%s@%s:%s" ' % (local_path, rpi_username, rpi_hostname, remote_path)) + # # -- not sure why this doesn't work... 
permission denied + # sftp = client.open_sftp() + # sftp.put(local_path, remote_path, confirm=True) + # sftp.close() + + # run script on Raspberry Pi to set mask pattern + command = f"{rpi_python} {script} --file_path {remote_path}" + print(f"COMMAND : {command}") + _stdin, _stdout, _stderr = client.exec_command(command) + print(_stdout.read().decode()) + client.close() + + os.remove(local_path) + + +def get_programmable_mask( + vals, + sensor, + slm_param, + rotate=None, + flipud=False, + nbits=8, +): + """ + Get mask as a numpy or torch array. Return same type. + + Parameters + ---------- + vals : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` + Values to set on programmable mask. + sensor : :py:class:`~lensless.hardware.sensor.VirtualSensor` + Sensor object. + slm_param : dict + SLM parameters. + rotate : float, optional + Rotation angle in degrees. + flipud : bool, optional + Flip mask vertically. + nbits : int, optional + Number of bits/levels to quantize mask to. + + """ + + use_torch = False + if torch_available: + use_torch = isinstance(vals, torch.Tensor) + dtype = vals.dtype + + # -- prepare SLM mask + n_active_slm_pixels = vals.shape + n_color_filter = np.prod(slm_param["color_filter"].shape[:2]) + pixel_pitch = slm_param[SLMParam_wp.PITCH] + centers = get_centers(n_active_slm_pixels, pixel_pitch=pixel_pitch) + + if SLMParam_wp.COLOR_FILTER in slm_param.keys(): + color_filter = slm_param[SLMParam_wp.COLOR_FILTER] + if flipud: + color_filter = np.flipud(color_filter) + + cf = get_color_filter( + slm_dim=n_active_slm_pixels, + color_filter=color_filter, + shift=0, + flat=True, + ) + + else: + + # monochrome + cf = None + + d1 = sensor.pitch + _height_pixel, _width_pixel = (slm_param[SLMParam_wp.CELL_SIZE] / d1).astype(int) + + if use_torch: + mask = torch.zeros((n_color_filter,) + tuple(sensor.resolution)).to(vals) + slm_vals_flat = vals.flatten() + else: + mask = np.zeros((n_color_filter,) + tuple(sensor.resolution), dtype=dtype) + slm_vals_flat = vals.reshape(-1) + + for i, _center in enumerate(centers): + + _center_pixel = (_center / d1 + sensor.resolution / 2).astype(int) + _center_top_left_pixel = ( + _center_pixel[0] - np.floor(_height_pixel / 2).astype(int), + _center_pixel[1] + 1 - np.floor(_width_pixel / 2).astype(int), + ) + + if cf is not None: + _rect = np.tile(cf[i][:, np.newaxis, np.newaxis], (1, _height_pixel, _width_pixel)) + else: + _rect = np.ones((1, _height_pixel, _width_pixel)) + + if use_torch: + _rect = torch.tensor(_rect).to(slm_vals_flat) + + mask[ + :, + _center_top_left_pixel[0] : _center_top_left_pixel[0] + _height_pixel, + _center_top_left_pixel[1] : _center_top_left_pixel[1] + _width_pixel, + ] = ( + slm_vals_flat[i] * _rect + ) + + # quantize mask + if use_torch: + mask = mask / torch.max(mask) + mask = torch.round(mask * (2**nbits - 1)) / (2**nbits - 1) + else: + mask = mask / np.max(mask) + mask = np.round(mask * (2**nbits - 1)) / (2**nbits - 1) + + # rotate + if rotate is not None: + if use_torch: + mask = transforms.functional.rotate(mask, angle=rotate) + else: + mask = rotate_func(mask, axes=(2, 1), angle=rotate, reshape=False) + + return mask + + +def get_intensity_psf( + mask, + waveprop=False, + sensor=None, + scene2mask=None, + mask2sensor=None, + color_system=None, +): + """ + Get intensity PSF from mask pattern. Return same type of data. + + Parameters + ---------- + mask : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` + Mask pattern. + waveprop : bool, optional + Whether to use wave propagation to compute PSF. 
Default is False, + namely to return squared intensity of mask pattern as the PSF (i.e., + no wave propagation and just shadow of pattern). + sensor : :py:class:`~lensless.hardware.sensor.VirtualSensor` + Sensor object. Not used if ``waveprop=False``. + scene2mask : float + Distance from scene to mask. Not used if ``waveprop=False``. + mask2sensor : float + Distance from mask to sensor. Not used if ``waveprop=False``. + color_system : :py:class:`~waveprop.color.ColorSystem`, optional + Color system. Not used if ``waveprop=False``. + + """ + if color_system is None: + color_system = ColorSystem.rgb() + + is_torch = False + device = None + if torch_available: + is_torch = isinstance(mask, torch.Tensor) + device = mask.device + + dtype = mask.dtype + ctype, _ = get_ctypes(dtype, is_torch) + + if is_torch: + psfs = torch.zeros(mask.shape, dtype=ctype, device=device) + else: + psfs = np.zeros(mask.shape, dtype=ctype) + + if waveprop: + + assert sensor is not None, "sensor must be specified" + assert scene2mask is not None, "scene2mask must be specified" + assert mask2sensor is not None, "mask2sensor must be specified" + + assert ( + len(color_system.wv) == mask.shape[0] + ), "Number of wavelengths must match number of color channels" + + # spherical wavefronts to mask + spherical_wavefront = spherical_prop( + in_shape=sensor.resolution, + d1=sensor.pitch, + wv=color_system.wv, + dz=scene2mask, + return_psf=True, + is_torch=True, + device=device, + dtype=dtype, + ) + u_in = spherical_wavefront * mask + + # free space propagation to sensor + for i, wv in enumerate(color_system.wv): + psfs[i], _, _ = angular_spectrum( + u_in=u_in[i], + wv=wv, + d1=sensor.pitch, + dz=mask2sensor, + dtype=dtype, + device=device, + ) + + else: + + psfs = mask + + # -- intensity PSF + if is_torch: + psf_in = torch.square(torch.abs(psfs)) + else: + psf_in = np.square(np.abs(psfs)) + + return psf_in diff --git a/lensless/hardware/utils.py b/lensless/hardware/utils.py index a0c0d573..97b384f6 100644 --- a/lensless/hardware/utils.py +++ b/lensless/hardware/utils.py @@ -2,6 +2,7 @@ import os import socket import subprocess +import time import paramiko from paramiko.ssh_exception import AuthenticationException, BadHostKeyException, SSHException @@ -65,7 +66,7 @@ def check_username_hostname(username, hostname, timeout=10): except (BadHostKeyException, AuthenticationException, SSHException, socket.error) as e: raise ValueError(f"Could not connect to {username}@{hostname}\n{e}") - return username, hostname + return client def get_distro(): @@ -92,3 +93,62 @@ def get_distro(): # Just major version shown, replace it with the full version RELEASE_DATA["VERSION"] = " ".join([DEBIAN_VERSION] + version_split[1:]) return f"{RELEASE_DATA['NAME']} {RELEASE_DATA['VERSION']}" + + +def set_mask_sensor_distance(distance, rpi_username, rpi_hostname, motor=1): + """ + Set the distance between the mask and sensor. + + This functions assumes that `StepperDriver `_ is installed. + is downloaded on the Raspberry Pi. + + Parameters + ---------- + distance : float + Distance in mm. Positive values move the mask away from the sensor. + rpi_username : str + Username of Raspberry Pi. + rpi_hostname : str + Hostname of Raspberry Pi. 
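+    motor : int, optional
+        Index of the stepper motor to drive (0 or 1), as passed to
+        ``serial_motors.py``. Defaults to 1.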
+ """ + + MAX_DISTANCE = 16 # mm + timeout = 5 + + client = check_username_hostname(rpi_username, rpi_hostname) + assert motor in [0, 1] + assert distance >= 0, "Distance must be non-negative" + assert distance < MAX_DISTANCE, f"Distance must be less than {MAX_DISTANCE} mm" + + # assumes that `StepperDriver` is in home directory + rpi_python = "python3" + script = "~/StepperDriver/Python/serial_motors.py" + + # reset to zero + print("Resetting to zero distance...") + try: + command = f"{rpi_python} {script} {motor} REV {MAX_DISTANCE * 1000}" + _stdin, _stdout, _stderr = client.exec_command(command, timeout=timeout) + except socket.timeout: # socket.timeout + pass + + client.close() + time.sleep(5) # TODO reduce this time + client = check_username_hostname(rpi_username, rpi_hostname) + + # set to desired distance + if distance != 0: + print(f"Setting distance to {distance} mm...") + distance_um = distance * 1000 + if distance_um >= 0: + command = f"{rpi_python} {script} {motor} FWD {distance_um}" + else: + command = f"{rpi_python} {script} {motor} REV {-1 * distance_um}" + print(f"COMMAND : {command}") + try: + _stdin, _stdout, _stderr = client.exec_command(command, timeout=timeout) + print(_stdout.read().decode()) + except socket.timeout: # socket.timeout + client.close() + + client.close() diff --git a/lensless/recon/apgd.py b/lensless/recon/apgd.py index 2ae5a69d..327c32de 100644 --- a/lensless/recon/apgd.py +++ b/lensless/recon/apgd.py @@ -11,7 +11,9 @@ import inspect import numpy as np from typing import Optional +from lensless.utils.image import resize from lensless.recon.rfft_convolve import RealFFTConvolve2D as Convolver +import cv2 import pycsou.abc as pyca import pycsou.operator.func as func @@ -20,6 +22,7 @@ import pycsou.runtime as pycrt import pycsou.util as pycu import pycsou.util.ptype as pyct +import pycsou.operator.linop as pycl class APGDPriors: @@ -95,6 +98,7 @@ def __init__( rel_error=None, lipschitz_tight=True, lipschitz_tol=1.0, + img_shape=None, **kwargs ): """ @@ -132,27 +136,52 @@ def __init__( Whether to use tight Lipschitz constant or not. Default is True. lipschitz_tol : float, optional Tolerance to compute Lipschitz constant. Default is 1. + img_shape : tuple, optional + Shape of measurement (H, W, C). If None, assume shape of PSF. 
""" assert isinstance(psf, np.ndarray), "PSF must be a numpy array" - # PSF and data are the same size / shape self._original_shape = psf.shape - self._original_size = psf.size - self._apgd = None - self._gen = None - - super(APGD, self).__init__(psf, dtype, n_iter=max_iter, **kwargs) self._stop_crit = stop.MaxIter(max_iter) if rel_error is not None: self._stop_crit = self._stop_crit | stop.RelError(eps=rel_error) self._disp = disp - # Convolution operator + # Convolution (and optional downsampling) operator + if img_shape is not None: + + meas_shape = np.array(img_shape[:2]) + rec_shape = np.array(self._original_shape[1:3]) + assert np.all(meas_shape <= rec_shape), "Image shape must be smaller than PSF shape" + self.downsampling_factor = np.round(rec_shape / meas_shape).astype(int) + + # new PSF shape, must be integer multiple of image shape + new_shape = tuple(np.array(meas_shape) * self.downsampling_factor) + (psf.shape[-1],) + psf_re = resize(psf.copy(), shape=new_shape, interpolation=cv2.INTER_CUBIC) + + # combine operations + conv = RealFFTConvolve2D(psf_re, dtype=dtype) + ds = pycl.SubSample( + psf_re.shape, + slice(None), + slice(0, -1, self.downsampling_factor[0]), + slice(0, -1, self.downsampling_factor[1]), + slice(None), + ) + + self._H = ds * conv + + super(APGD, self).__init__(psf_re, dtype, n_iter=max_iter, **kwargs) + + else: + self.downsampling_factor = 1 + self._H = RealFFTConvolve2D(psf, dtype=dtype) + + super(APGD, self).__init__(psf, dtype, n_iter=max_iter, **kwargs) - self._H = RealFFTConvolve2D(self._psf, dtype=dtype) self._H.lipschitz(tol=lipschitz_tol, tight=lipschitz_tight) # initialize solvers which will be created when data is set @@ -192,9 +221,25 @@ def set_data(self, data): 3D (RGB). """ - super(APGD, self).set_data( - np.repeat(data, self._original_shape[-4], axis=0) - ) # we repeat the data for each depth to match the size of the PSF + + # super(APGD, self).set_data( + # np.repeat(data, self._original_shape[-4], axis=0) + # ) # we repeat the data for each depth to match the size of the PSF + + data = np.repeat(data, self._original_shape[-4], axis=0) # repeat for each depth + assert isinstance(data, np.ndarray) + assert len(data.shape) >= 3, "Data must be at least 3D: [..., width, height, channel]." + + assert np.all( + self._psf_shape[-3:-1] == (np.array(data.shape)[-3:-1] * self.downsampling_factor) + ), "PSF and data shape mismatch" + + if len(data.shape) == 3: + self._data = data[None, None, ...] + elif len(data.shape) == 4: + self._data = data[None, ...] + else: + self._data = data """ Set up problem """ # Cost function @@ -220,13 +265,15 @@ def reset(self): if self._initial_est is not None: self._image_est = self._initial_est else: - self._image_est = np.zeros(self._original_size, dtype=self._dtype) + self._image_est = np.zeros(np.prod(self._psf_shape), dtype=self._dtype) def _update(self, iter): res = next(self._apgd.steps()) self._image_est[:] = res["x"] def _form_image(self): - image = self._image_est.reshape(self._original_shape) + image = self._image_est.reshape(self._psf_shape) image[image < 0] = 0 + if np.any(self._psf_shape != self._original_shape): + image = resize(image, shape=self._original_shape) return image diff --git a/lensless/recon/recon.py b/lensless/recon/recon.py index 58200f2a..1124c289 100644 --- a/lensless/recon/recon.py +++ b/lensless/recon/recon.py @@ -10,7 +10,7 @@ ============== The core algorithmic component of ``LenslessPiCam`` is the abstract -class :py:class:`~lensless.ReconstructionAlgorithm`. 
The three reconstruction +class :py:class:`~lensless.ReconstructionAlgorithm`. The five reconstruction strategies available in ``LenslessPiCam`` derive from this class: - :py:class:`~lensless.GradientDescent`: projected gradient descent with a @@ -25,6 +25,14 @@ class :py:class:`~lensless.ReconstructionAlgorithm`. The three reconstruction long as it is compatible with Pycsou, namely derives from one of `DiffFunc `_ or `ProxFunc `_. +- :py:class:`~lensless.UnrolledFISTA`: unrolled FISTA with a non-negativity constraint. +- :py:class:`~lensless.UnrolledADMM`: unrolled ADMM with a non-negativity constraint and a total variation (TV) regularizer [1]_. + +Note that the unrolled algorithms derive from the abstract class +:py:class:`~lensless.TrainableReconstructionAlgorithm`, which itself derives from +:py:class:`~lensless.ReconstructionAlgorithm` while adding functionality +for training on batches and adding trainable pre- and post-processing +blocks. New reconstruction algorithms can be conveniently implemented by deriving from the abstract class and defining the following abstract @@ -154,6 +162,7 @@ class :py:class:`~lensless.ReconstructionAlgorithm`. The three reconstruction import pathlib as plib import matplotlib.pyplot as plt from lensless.utils.plot import plot_image +from lensless.utils.io import get_dtype from lensless.recon.rfft_convolve import RealFFTConvolve2D try: @@ -232,16 +241,7 @@ def __init__( self._psf_shape = np.array(self._psf.shape) # set dtype - if dtype is None: - if self.is_torch: - dtype = torch.float32 - else: - dtype = np.float32 - else: - if self.is_torch: - dtype = torch.float32 if dtype == "float32" else torch.float64 - else: - dtype = np.float32 if dtype == "float32" else np.float64 + dtype = get_dtype(dtype, self.is_torch) if self.is_torch: if dtype: @@ -491,7 +491,9 @@ def apply( if (plot or save) and disp_iter is not None: if ax is None: - ax = plot_image(self._get_numpy_data(self._image_est[0]), gamma=gamma) + img = self._form_image() + ax = plot_image(self._get_numpy_data(img[0]), gamma=gamma) + else: ax = None disp_iter = n_iter + 1 diff --git a/lensless/recon/trainable_recon.py b/lensless/recon/trainable_recon.py index c7129a3b..e554f6b0 100644 --- a/lensless/recon/trainable_recon.py +++ b/lensless/recon/trainable_recon.py @@ -5,7 +5,6 @@ # Yohann PERRON [yohann.perron@gmail.com] # ############################################################################# -import abc from lensless.recon.recon import ReconstructionAlgorithm try: @@ -24,7 +23,6 @@ class TrainableReconstructionAlgorithm(ReconstructionAlgorithm, torch.nn.Module) * ``_update``: updating state variables at each iterations. * ``reset``: reset state variables. * ``_form_image``: any pre-processing that needs to be done in order to view the image estimate, e.g. reshaping or clipping. - * ``batch_call``: method for performing iterative reconstruction on a batch of images. 
One advantage of deriving from this abstract class is that functionality for iterating, saving, and visualization is already implemented, namely in the diff --git a/lensless/utils/image.py b/lensless/utils/image.py index 7d2c65b3..19c977e2 100644 --- a/lensless/utils/image.py +++ b/lensless/utils/image.py @@ -1,5 +1,5 @@ # ############################################################################# -# image_utils.py +# image.py # ================= # Authors : # Eric BEZZAM [ebezzam@gmail.com] diff --git a/lensless/utils/io.py b/lensless/utils/io.py index 57c4f740..f502719a 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -1,3 +1,11 @@ +# ############################################################################# +# io.py +# ================= +# Authors : +# Eric BEZZAM [ebezzam@gmail.com] +# ############################################################################# + + import warnings from PIL import Image import cv2 @@ -6,7 +14,7 @@ from lensless.utils.plot import plot_image from lensless.hardware.constants import RPI_HQ_CAMERA_BLACK_LEVEL, RPI_HQ_CAMERA_CCM_MATRIX -from lensless.utils.image import bayer2rgb_cc, print_image_info, resize, rgb2gray +from lensless.utils.image import bayer2rgb_cc, print_image_info, resize, rgb2gray, get_max_val def load_image( @@ -22,6 +30,10 @@ def load_image( nbits_out=None, as_4d=False, downsample=None, + bg=None, + return_float=False, + shape=None, + dtype=None, ): """ Load image as numpy array. @@ -53,6 +65,15 @@ def load_image( height, width, color). downsample : int, optional Downsampling factor. Recommended for image reconstruction. + bg : array_like + Background level to subtract. + return_float : bool + Whether to return image as float array, or unsigned int. + shape : tuple, optional + Shape (H, W, C) to resize to. + dtype : str, optional + Data type of returned data. Default is to use that of input. 
+ Returns ------- img : :py:class:`~numpy.ndarray` @@ -103,6 +124,8 @@ def load_image( if len(img.shape) == 3: img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + original_dtype = img.dtype + if flip: img = np.flipud(img) img = np.fliplr(img) @@ -110,14 +133,39 @@ def load_image( if verbose: print_image_info(img) + if bg is not None: + + # if bg is float vector, turn into int-valued vector + if bg.max() <= 1: + bg = bg * get_max_val(img) + + img = img - bg + img = np.clip(img, a_min=0, a_max=img.max()) + if as_4d: if len(img.shape) == 3: img = img[np.newaxis, :, :, :] elif len(img.shape) == 2: img = img[np.newaxis, :, :, np.newaxis] - if downsample is not None: - img = resize(img, factor=1 / downsample) + if downsample is not None or shape is not None: + if downsample is not None: + factor = 1 / downsample + else: + factor = None + img = resize(img, factor=factor, shape=shape) + + if return_float: + if dtype is None: + dtype = np.float32 + assert dtype == np.float32 or dtype == np.float64 + img = img.astype(dtype) + img /= img.max() + + else: + if dtype is None: + dtype = original_dtype + img = img.astype(dtype) return img @@ -212,6 +260,7 @@ def load_psf( ) original_dtype = psf.dtype + max_val = get_max_val(psf) psf = np.array(psf, dtype=dtype) if use_3d: @@ -274,6 +323,7 @@ def load_psf( if return_float: # psf /= psf.max() psf /= np.linalg.norm(psf.ravel()) + bg /= max_val else: psf = psf.astype(original_dtype) @@ -379,21 +429,21 @@ def load_data( ) # load and process raw measurement - data = load_image(data_fp, flip=flip, bayer=bayer, blue_gain=blue_gain, red_gain=red_gain) - data = np.array(data, dtype=dtype) - - data -= bg - data = np.clip(data, a_min=0, a_max=data.max()) - - if len(data.shape) == 3: - data = data[np.newaxis, :, :, :] - elif len(data.shape) == 2: - data = data[np.newaxis, :, :, np.newaxis] + data = load_image( + data_fp, + flip=flip, + bayer=bayer, + blue_gain=blue_gain, + red_gain=red_gain, + bg=bg, + as_4d=True, + return_float=True, + shape=shape, + ) if data.shape != psf.shape: # in DiffuserCam dataset, images are already reshaped data = resize(data, shape=psf.shape) - data /= np.linalg.norm(data.ravel()) if data.shape[3] > 1 and psf.shape[3] == 1: warnings.warn( @@ -454,3 +504,58 @@ def save_image(img, fp, max_val=255): img = Image.fromarray(img) img.save(fp) + + +def get_dtype(dtype=None, is_torch=False): + """ + Get dtype for numpy or torch. + + Parameters + ---------- + dtype : str, optional + "float32" or "float64", Default is "float32". + is_torch : bool, optional + Whether to return torch dtype. 
+ """ + if dtype is None: + dtype = "float32" + assert dtype == "float32" or dtype == "float64" + + if is_torch: + import torch + + if dtype is None: + if is_torch: + dtype = torch.float32 + else: + dtype = np.float32 + else: + if is_torch: + dtype = torch.float32 if dtype == "float32" else torch.float64 + else: + dtype = np.float32 if dtype == "float32" else np.float64 + + return dtype + + +def get_ctypes(dtype, is_torch): + if not is_torch: + if dtype == np.float32 or dtype == np.complex64: + return np.complex64, np.complex64 + elif dtype == np.float64 or dtype == np.complex128: + return np.complex128, np.complex128 + else: + raise ValueError("Unexpected dtype: ", dtype) + else: + import torch + + if dtype == np.float32 or dtype == np.complex64: + return torch.complex64, np.complex64 + elif dtype == np.float64 or dtype == np.complex128: + return torch.complex128, np.complex128 + elif dtype == torch.float32 or dtype == torch.complex64: + return torch.complex64, np.complex64 + elif dtype == torch.float64 or dtype == torch.complex128: + return torch.complex128, np.complex128 + else: + raise ValueError("Unexpected dtype: ", dtype) diff --git a/recon_requirements.txt b/recon_requirements.txt index 5d142936..4ebe4412 100644 --- a/recon_requirements.txt +++ b/recon_requirements.txt @@ -7,6 +7,6 @@ click>=8.0.1 waveprop>=0.0.3 # for simulation # Library for learning algorithm -torch >= 1.8.0 +torch >= 2.0.0 torchvision lpips \ No newline at end of file diff --git a/scripts/classify/train_celeba_vit.py b/scripts/classify/train_celeba_vit.py new file mode 100644 index 00000000..79a32e44 --- /dev/null +++ b/scripts/classify/train_celeba_vit.py @@ -0,0 +1,330 @@ +""" +Fine-tune ViT on CelebA dataset measured with lensless camera. +Original tutorial: https://huggingface.co/blog/fine-tune-vit + +First, set-up HuggingFace libraries: +``` +pip install datasets transformers +``` + +Raw measurement datasets can be download from SwitchDrive. +This will be done by the script if the dataset is not found. +``` +# 10K measurements (13.1 GB) +python scripts/classify/train_celeba_vit.py \ +data.measured=data/celeba_adafruit_random_2mm_20230720_10K + +# 1K measurements (1.2 GB) +python scripts/classify/train_celeba_vit.py \ +data.measured=data/celeba_adafruit_random_2mm_20230720_1K +``` + +Note that the CelebA dataset also needs to be available locally! +It can be download here: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html + +In order to classify on reconstructed outputs, the following +script needs to be run to create the dataset of reconstructed +images: +``` +# reconstruct with ADMM +python scripts/recon/dataset.py algo=admm \ +input.raw_data=path/to/raw/data +``` + +To classify on raw downsampled images, the same script can be +used, e.g. with the following command (`algo=null` for no reconstruction): +``` +python scripts/recon/dataset.py algo=null \ +input.raw_data=path/to/raw/data \ +preprocess.data_dim=[48,64] +``` + +Other hyperparameters for classification can be found in +`configs/train_celeba_classifier.yaml`. 
+ +""" + +import warnings +from transformers import ViTImageProcessor, ViTForImageClassification +from transformers import TrainingArguments, Trainer, TrainerCallback +import numpy as np +import torch +import os +from hydra.utils import to_absolute_path +import glob +import hydra +import random +from datasets import load_metric +from PIL import Image +import pandas as pd +import time +import torchvision.transforms as transforms +import torchvision.datasets as dset +from datasets import Dataset +from copy import deepcopy +from torchvision.transforms import ( + CenterCrop, + Compose, + Normalize, + RandomHorizontalFlip, + RandomResizedCrop, + Resize, + ToTensor, +) + + +class CustomCallback(TrainerCallback): + def __init__(self, trainer) -> None: + super().__init__() + self._trainer = trainer + + def on_epoch_end(self, args, state, control, **kwargs): + if control.should_evaluate: + control_copy = deepcopy(control) + self._trainer.evaluate( + eval_dataset=self._trainer.train_dataset, metric_key_prefix="train" + ) + return control_copy + + def on_step_end(self, args, state, control, **kwargs): + if control.should_evaluate: + control_copy = deepcopy(control) + self._trainer.evaluate( + eval_dataset=self._trainer.train_dataset, metric_key_prefix="train" + ) + return control_copy + + def on_train_end(self, args, state, control, **kwargs): + if control.should_evaluate: + control_copy = deepcopy(control) + self._trainer.evaluate( + eval_dataset=self._trainer.train_dataset, metric_key_prefix="train" + ) + return control_copy + + +@hydra.main(version_base=None, config_path="../../configs", config_name="train_celeba_classifier") +def train_celeba_classifier(config): + + seed = config.seed + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + # check how many measured files + measured_dataset = to_absolute_path(config.data.measured) + if not os.path.isdir(measured_dataset): + print(f"No dataset found at {measured_dataset}") + try: + from torchvision.datasets.utils import download_and_extract_archive + except ImportError: + exit() + msg = "Do you want to download the CelebA dataset measured with a random Adafruit LCD pattern (13.1 GB)?" 
+ + # default to yes if no input is given + valid = input("%s (Y/n) " % msg).lower() != "n" + if valid: + url = "https://drive.switch.ch/index.php/s/9NNGCJs3DoBDGlY/download" + filename = "celeba_adafruit_random_2mm_20230720_10K.zip" + download_and_extract_archive( + url, os.path.dirname(measured_dataset), filename=filename, remove_finished=True + ) + measured_files = sorted(glob.glob(os.path.join(measured_dataset, "*.png"))) + print(f"Found {len(measured_files)} files in {measured_dataset}") + + if config.data.n_files is not None: + n_files = config.data.n_files + measured_files = measured_files[: config.data.n_files] + print(f"Using {len(measured_files)} files") + n_files = len(measured_files) + + # create dataset split + attr = config.data.attr + ds = dset.CelebA( + root=config.data.original, + split="all", + download=False, + transform=transforms.ToTensor(), + ) + label_idx = ds.attr_names.index(attr) + labels = ds.attr[:, label_idx][:n_files] + + # make dataset with measured data and corresponding labels + df = pd.DataFrame( + { + "labels": labels, + "image_file_path": measured_files, + } + ) + ds = Dataset.from_pandas(df) + ds = ds.class_encode_column("labels") + + # -- train / test split + test_size = config.data.test_size + ds = ds.train_test_split( + test_size=test_size, stratify_by_column="labels", seed=seed, shuffle=True + ) + + # prepare dataset + model_name_or_path = "google/vit-base-patch16-224-in21k" + processor = ViTImageProcessor.from_pretrained(model_name_or_path) + + # -- processors for train and val + image_mean, image_std = processor.image_mean, processor.image_std + size = processor.size["height"] + + normalize = Normalize(mean=image_mean, std=image_std) + # _train_transforms = Compose( + # [ + # # RandomResizedCrop( + # # size, + # # scale=(0.9, 1.0), + # # ratio=(0.9, 1.1), + # # ), + # Resize(size), + # CenterCrop(size), + # RandomHorizontalFlip(), + # ToTensor(), + # normalize, + # ] + # ) + _train_transforms = [] + if config.augmentation.random_resize_crop: + _train_transforms.append( + RandomResizedCrop( + size, + scale=(0.9, 1.0), + ratio=(0.9, 1.1), + ) + ) + _train_transforms.append( + Resize(size), + CenterCrop(size), + ) + if config.augmentation.horizontal_flip: + if config.data.raw: + warnings.warn("Horizontal flip is not supported for raw data, Skipping!") + else: + _train_transforms.append(RandomHorizontalFlip()) + _train_transforms.append( + ToTensor(), + normalize, + ) + _train_transforms = Compose(_train_transforms) + + _val_transforms = Compose( + [ + Resize(size), + CenterCrop(size), + ToTensor(), + normalize, + ] + ) + + def train_transforms(examples): + # Take a list of PIL images and turn them to pixel values + examples["pixel_values"] = [ + _train_transforms(Image.open(fp)) for fp in examples["image_file_path"] + ] + return examples + + def val_transforms(examples): + # Take a list of PIL images and turn them to pixel values + examples["pixel_values"] = [ + _val_transforms(Image.open(fp)) for fp in examples["image_file_path"] + ] + return examples + + # transform dataset + ds["train"].set_transform(train_transforms) + ds["test"].set_transform(val_transforms) + + # data collator + def collate_fn(batch): + return { + "pixel_values": torch.stack([x["pixel_values"] for x in batch]), + "labels": torch.tensor([x["labels"] for x in batch]), + } + + # evaluation metric + metric = load_metric("accuracy") + + def compute_metrics(p): + return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids) + + # load model + if 
config.train.prev is not None: + model_path = to_absolute_path(config.train.prev) + else: + model_path = model_name_or_path + + labels = ds["train"].features["labels"].names + model = ViTForImageClassification.from_pretrained( + model_path, + num_labels=len(labels), + id2label={str(i): c for i, c in enumerate(labels)}, + label2id={c: str(i) for i, c in enumerate(labels)}, + hidden_dropout_prob=config.train.dropout, + attention_probs_dropout_prob=config.train.dropout, + ) + + # configure training + output_dir = ( + config.data.output_dir + f"-{config.data.attr}" + os.path.basename(measured_dataset) + ) + + training_args = TrainingArguments( + output_dir=output_dir, + per_device_train_batch_size=config.train.batch_size, + evaluation_strategy="steps", + eval_steps=100, + save_steps=100, + num_train_epochs=config.train.n_epochs, + fp16=True, + logging_steps=10, + learning_rate=config.train.learning_rate, + save_total_limit=2, + remove_unused_columns=False, # important to keep False + push_to_hub=False, + report_to="tensorboard", + load_best_model_at_end=True, + ) + + trainer = Trainer( + model=model, + args=training_args, + data_collator=collate_fn, + compute_metrics=compute_metrics, + tokenizer=processor, + train_dataset=ds["train"], + eval_dataset=ds["test"], + ) + trainer.add_callback(CustomCallback(trainer)) # add accuracy on train set + + # train + hydra_output = os.getcwd() + print("Results saved to : ", hydra_output) + + start_time = time.time() + train_results = trainer.train() + trainer.save_model() + trainer.log_metrics("train", train_results.metrics) + trainer.save_metrics("train", train_results.metrics) + trainer.save_state() + + # evaluate + metrics = trainer.evaluate(ds["train"]) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + metrics = trainer.evaluate(ds["test"]) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # + hydra_output = os.getcwd() + print("Results saved to : ", hydra_output) + print(f"Training took {time.time() - start_time} seconds") + + +if __name__ == "__main__": + train_celeba_classifier() diff --git a/scripts/demo.py b/scripts/demo.py index 32b26e42..760b663a 100644 --- a/scripts/demo.py +++ b/scripts/demo.py @@ -18,7 +18,9 @@ @hydra.main(version_base=None, config_path="../configs", config_name="demo") def demo(config): - RPI_USERNAME, RPI_HOSTNAME = check_username_hostname(config.rpi.username, config.rpi.hostname) + check_username_hostname(config.rpi.username, config.rpi.hostname) + RPI_USERNAME = config.rpi.username + RPI_HOSTNAME = config.rpi.hostname display_fp = to_absolute_path(config.fp) if config.save: diff --git a/scripts/hardware/config_digicam.py b/scripts/hardware/config_digicam.py new file mode 100644 index 00000000..cd8cab86 --- /dev/null +++ b/scripts/hardware/config_digicam.py @@ -0,0 +1,101 @@ +import warnings +import hydra +from datetime import datetime +import numpy as np +from slm_controller import slm +from slm_controller.hardware import SLMParam, slm_devices +import matplotlib.pyplot as plt + +from lensless.hardware.slm import set_programmable_mask +from lensless.hardware.aperture import rect_aperture, circ_aperture +from lensless.hardware.utils import set_mask_sensor_distance + + +@hydra.main(version_base=None, config_path="../../configs", config_name="digicam") +def config_digicam(config): + + rpi_username = config.rpi.username + rpi_hostname = config.rpi.hostname + device = config.device + + shape = slm_devices[device][SLMParam.SLM_SHAPE] + if not 
slm_devices[device][SLMParam.MONOCHROME]: + shape = (3, *shape) + pixel_pitch = slm_devices[device][SLMParam.PIXEL_PITCH] + + # set mask to sensor distance + if config.z is not None and not config.virtual: + set_mask_sensor_distance(config.z, rpi_username, rpi_hostname) + + center = np.array(config.center) * pixel_pitch + + # create random pattern + pattern = None + if config.pattern.endswith(".npy"): + pattern = np.load(config.pattern) + elif config.pattern == "random": + rng = np.random.RandomState(1) + # pattern = rng.randint(low=0, high=np.iinfo(np.uint8).max, size=shape, dtype=np.uint8) + pattern = rng.uniform(low=config.min_val, high=1, size=shape) + pattern = (pattern * np.iinfo(np.uint8).max).astype(np.uint8) + + elif config.pattern == "rect": + rect_shape = config.rect_shape + apert_dim = rect_shape[0] * pixel_pitch[0], rect_shape[1] * pixel_pitch[1] + ap = rect_aperture( + apert_dim=apert_dim, + slm_shape=slm_devices[device][SLMParam.SLM_SHAPE], + pixel_pitch=pixel_pitch, + center=center, + ) + pattern = ap.values + elif config.pattern == "circ": + ap = circ_aperture( + radius=config.radius * pixel_pitch[0], + slm_shape=slm_devices[device][SLMParam.SLM_SHAPE], + pixel_pitch=pixel_pitch, + center=center, + ) + pattern = ap.values + else: + raise ValueError(f"Pattern {config.pattern} not supported") + + # save pattern + if not config.pattern.endswith(".npy") and config.save: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + pattern_fn = f"{device}_{config.pattern}_pattern_{timestamp}.npy" + np.save(pattern_fn, pattern) + print(f"Saved pattern to {pattern_fn}") + + print("Pattern shape : ", pattern.shape) + print("Pattern dtype : ", pattern.dtype) + print("Pattern min : ", pattern.min()) + print("Pattern max : ", pattern.max()) + + # apply aperture + if config.aperture is not None: + + aperture = np.zeros(shape, dtype=np.uint8) + top_left = np.array(config.aperture.center) - np.array(config.aperture.shape) // 2 + bottom_right = top_left + np.array(config.aperture.shape) + aperture[:, top_left[0] : bottom_right[0], top_left[1] : bottom_right[1]] = 1 + pattern = pattern * aperture + + assert pattern is not None + + n_nonzero = np.count_nonzero(pattern) + print(f"Nonzero pixels: {n_nonzero}") + + if not config.virtual: + set_programmable_mask(pattern, device, rpi_username, rpi_hostname) + + # preview mask + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + s = slm.create(device) + s._show_preview(pattern) + plt.savefig("preview.png") + + +if __name__ == "__main__": + config_digicam() diff --git a/scripts/hardware/digicam_measure_psfs.py b/scripts/hardware/digicam_measure_psfs.py new file mode 100644 index 00000000..901d24cb --- /dev/null +++ b/scripts/hardware/digicam_measure_psfs.py @@ -0,0 +1,60 @@ +import numpy as np +from lensless.hardware.utils import set_mask_sensor_distance +import hydra +import os +from datetime import datetime +from PIL import Image + +SATURATION_THRESHOLD = 0.01 + + +@hydra.main(version_base=None, config_path="../../configs", config_name="digicam") +def config_digicam(config): + + rpi_username = config.rpi.username + rpi_hostname = config.rpi.hostname + + mask_sensor_distances = np.arange(9) * 0.1 + exposure_time = 5 + + timestamp = datetime.now().strftime("%Y%m%d") + + for i in range(len(mask_sensor_distances)): + + print(f"Mask sensor distance: {mask_sensor_distances[i]}mm") + mask_sensor_distance = mask_sensor_distances[i] + + # set the mask sensor distance + set_mask_sensor_distance(mask_sensor_distance, rpi_username, 
rpi_hostname) + + good_exposure = False + while not good_exposure: + + # measure PSF + output_folder = f"adafruit_psf_{mask_sensor_distance}mm__{timestamp}" + os.system( + f"python scripts/remote_capture.py -cn capture_bayer output={output_folder} rpi.username={rpi_username} rpi.hostname={rpi_hostname} capture.exp={exposure_time}" + ) + + # check for saturation + OUTPUT_FP = os.path.join(output_folder, "raw_data.png") + # -- load picture to check for saturation + img = np.array(Image.open(OUTPUT_FP)) + ratio = np.sum(img == 4095) / np.prod(img.shape) + print(f"Saturation ratio: {ratio}") + if ratio > SATURATION_THRESHOLD or ratio == 0: + + if ratio == 0: + print("Need to increase exposure time.") + else: + print("Need to decrease exposure time.") + + # enter new exposure time from keyboard + exposure_time = float(input("Enter new exposure time: ")) + + else: + good_exposure = True + + +if __name__ == "__main__": + config_digicam() diff --git a/scripts/measure/remote_capture.py b/scripts/measure/remote_capture.py index 92f2033e..66210a86 100644 --- a/scripts/measure/remote_capture.py +++ b/scripts/measure/remote_capture.py @@ -32,7 +32,9 @@ def liveview(config): rgb = config.capture.rgb gray = config.capture.gray - username, hostname = check_username_hostname(config.rpi.username, config.rpi.hostname) + check_username_hostname(config.rpi.username, config.rpi.hostname) + username = config.rpi.username + hostname = config.rpi.hostname legacy = config.capture.legacy nbits_out = config.capture.nbits_out fn = config.capture.raw_data_fn diff --git a/scripts/measure/remote_display.py b/scripts/measure/remote_display.py index f9ab3ed2..1be931a3 100644 --- a/scripts/measure/remote_display.py +++ b/scripts/measure/remote_display.py @@ -35,12 +35,15 @@ @hydra.main(version_base=None, config_path="../../configs", config_name="demo") def remote_display(config): - username, hostname = check_username_hostname(config.rpi.username, config.rpi.hostname) + check_username_hostname(config.rpi.username, config.rpi.hostname) + username = config.rpi.username + hostname = config.rpi.hostname fp = config.fp shape = np.array(config.display.screen_res) psf = config.display.psf black = config.display.black + white = config.display.white if psf: point_source = np.zeros(tuple(shape) + (3,)) @@ -58,12 +61,18 @@ def remote_display(config): im = Image.fromarray(point_source.astype("uint8"), "RGB") im.save(fp) + elif white: + point_source = np.ones(tuple(shape) + (3,)) * 255 + fp = "tmp_display.png" + im = Image.fromarray(point_source.astype("uint8"), "RGB") + im.save(fp) + """ processing on remote machine, less issues with copying """ # copy picture to Raspberry Pi print("\nCopying over picture...") display(fp=fp, rpi_username=username, rpi_hostname=hostname, **config.display) - if psf or black: + if psf or black or white: os.remove(fp) diff --git a/scripts/recon/admm.py b/scripts/recon/admm.py index 17a88461..3ba3de1f 100644 --- a/scripts/recon/admm.py +++ b/scripts/recon/admm.py @@ -41,6 +41,7 @@ def admm(config): shape=config["preprocess"]["shape"], torch=config.torch, torch_device=config.torch_device, + bg_pix=config.preprocess.bg_pix, ) disp = config["display"]["disp"] diff --git a/scripts/recon/apgd_pycsou.py b/scripts/recon/apgd_pycsou.py index 878b378f..0bf236d0 100644 --- a/scripts/recon/apgd_pycsou.py +++ b/scripts/recon/apgd_pycsou.py @@ -17,7 +17,7 @@ import time import matplotlib.pyplot as plt from lensless.utils.io import load_data -from lensless import APGD +from lensless.recon.apgd import APGD import os import 
pathlib as plib @@ -28,7 +28,7 @@ log = logging.getLogger(__name__) -@hydra.main(version_base=None, config_path="../../configs", config_name="defaults_recon") +@hydra.main(version_base=None, config_path="../../configs", config_name="apgd_l1") def apgd( config, ): diff --git a/scripts/recon/dataset.py b/scripts/recon/dataset.py new file mode 100644 index 00000000..4c4192c5 --- /dev/null +++ b/scripts/recon/dataset.py @@ -0,0 +1,202 @@ +""" +Apply ADMM reconstruction to folder. + +``` +python scripts/recon/dataset.py +``` + +To run APGD, use the following command: +``` +python scripts/recon/dataset.py algo=apgd +``` + +To just copy resized raw data, use the following command: +``` +python scripts/recon/dataset.py algo=null preprocess.data_dim=[48,64] +``` + +""" + +import hydra +from hydra.utils import to_absolute_path +import os +import time +import numpy as np +from lensless.utils.io import load_psf, load_image, save_image +from lensless import ADMM +import torch +import glob +from tqdm import tqdm +from lensless.recon.apgd import APGD +from joblib import Parallel, delayed + + +@hydra.main(version_base=None, config_path="../../configs", config_name="recon_dataset") +def admm_dataset(config): + + algo = config.algo + + # get raw data file paths + dataset = to_absolute_path(config.input.raw_data) + if not os.path.isdir(dataset): + print(f"No dataset found at {dataset}") + try: + from torchvision.datasets.utils import download_and_extract_archive + except ImportError: + exit() + msg = "Do you want to download the sample CelebA dataset measured with a random Adafruit LCD pattern (1.2 GB)?" + + # default to yes if no input is given + valid = input("%s (Y/n) " % msg).lower() != "n" + if valid: + url = "https://drive.switch.ch/index.php/s/m89D1tFEfktQueS/download" + filename = "celeba_adafruit_random_2mm_20230720_1K.zip" + download_and_extract_archive( + url, os.path.dirname(dataset), filename=filename, remove_finished=True + ) + data_fps = sorted(glob.glob(os.path.join(dataset, "*.png"))) + if config.n_files is not None: + data_fps = data_fps[: config.n_files] + n_files = len(data_fps) + + # load PSF + psf_fp = to_absolute_path(config.input.psf) + flip = config.preprocess.flip + dtype = config.input.dtype + print("\nPSF:") + psf, bg = load_psf( + psf_fp, + verbose=True, + downsample=config.preprocess.downsample, + return_bg=True, + flip=flip, + dtype=dtype, + ) + print(f"Downsampled PSF shape: {psf.shape}") + + data_dim = None + if config.preprocess.data_dim is not None: + data_dim = tuple(config.preprocess.data_dim) + (psf.shape[-1],) + else: + data_dim = psf.shape + + # -- create output folder + output_folder = to_absolute_path(config.output_folder) + if algo == "apgd": + output_folder = output_folder + f"_apgd{config.apgd.max_iter}" + elif algo == "admm": + output_folder = output_folder + f"_admm{config.admm.n_iter}" + else: + output_folder = output_folder + "_raw" + output_folder = output_folder + f"_{data_dim[-3]}x{data_dim[-2]}" + os.makedirs(output_folder, exist_ok=True) + + # -- apply reconstruction + if algo == "apgd": + + start_time = time.time() + + def recover(i): + + # reconstruction object + recon = APGD(psf=psf, **config.apgd) + + data_fp = data_fps[i] + + # load data + data = load_image( + data_fp, flip=flip, bg=bg, as_4d=True, return_float=True, shape=data_dim + ) + data = data[0] # first depth + + # apply reconstruction + recon.set_data(data) + img = recon.apply( + disp_iter=config.display.disp, + gamma=config.display.gamma, + plot=config.display.plot, + ) + + # -- extract 
region of interest and save + if config.roi is not None: + roi = config.roi + img = img[roi[0] : roi[2], roi[1] : roi[3]] + + bn = os.path.basename(data_fp) + output_fp = os.path.join(output_folder, bn) + save_image(img, output_fp) + + n_jobs = config.apgd.n_jobs + if n_jobs > 1: + Parallel(n_jobs=n_jobs)(delayed(recover)(i) for i in range(n_files)) + else: + for i in tqdm(range(n_files)): + recover(i) + + else: + + if config.torch: + torch_dtype = torch.float32 + torch_device = config.torch_device + psf = torch.from_numpy(psf).type(torch_dtype).to(torch_device) + + # create reconstruction object + recon = None + if config.algo == "admm": + recon = ADMM(psf, **config.admm) + + # loop over files and apply reconstruction + start_time = time.time() + + for i in tqdm(range(n_files)): + data_fp = data_fps[i] + + # load data + data = load_image( + data_fp, flip=flip, bg=bg, as_4d=True, return_float=True, shape=data_dim + ) + + if config.torch: + data = torch.from_numpy(data).type(torch_dtype).to(torch_device) + + if recon is not None: + + # set data + recon.set_data(data) + + # apply reconstruction + res = recon.apply( + n_iter=config.admm.n_iter, + disp_iter=config.display.disp, + gamma=config.display.gamma, + plot=config.display.plot, + ) + + else: + + # copy resized raw data + res = data + + # save reconstruction as PNG + # -- take first depth + if config.torch: + img = res[0].cpu().numpy() + else: + img = res[0] + + # -- extract region of interest + if config.roi is not None: + img = img[config.roi[0] : config.roi[2], config.roi[1] : config.roi[3]] + + bn = os.path.basename(data_fp) + output_fp = os.path.join(output_folder, bn) + save_image(img, output_fp) + + print(f"Processing time : {time.time() - start_time} s") + # time per file + print(f"Time per file : {(time.time() - start_time) / n_files} s") + print("Files saved to: ", output_folder) + + +if __name__ == "__main__": + admm_dataset() diff --git a/scripts/sim/dataset.py b/scripts/sim/dataset.py index 2c08ba71..263d01f2 100644 --- a/scripts/sim/dataset.py +++ b/scripts/sim/dataset.py @@ -32,7 +32,7 @@ def simulate(config): if not os.path.isdir(dataset): print(f"No dataset found at {dataset}") try: - from torchvision.datasets.utils import download_and_extract_archive, download_url + from torchvision.datasets.utils import download_and_extract_archive except ImportError: exit() msg = "Do you want to download the sample CelebA dataset (764KB)?" 
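Note on the `roi` option consumed by the new scripts/recon/dataset.py above: the slicing `img[roi[0] : roi[2], roi[1] : roi[3]]` implies the ordering [top, left, bottom, right], in pixel coordinates of the (downsampled) reconstruction. A minimal sketch of that convention, with placeholder numbers rather than values from this change set:

import numpy as np

# img: a single reconstruction of shape (H, W, C), e.g. the first depth returned by recon.apply()
img = np.zeros((120, 240, 3))

roi = [10, 20, 110, 220]  # top, left, bottom, right (placeholder values)
crop = img[roi[0] : roi[2], roi[1] : roi[3]]
print(crop.shape)  # -> (100, 200, 3)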
diff --git a/scripts/sim/digicam_psf.py b/scripts/sim/digicam_psf.py new file mode 100644 index 00000000..d0e0636b --- /dev/null +++ b/scripts/sim/digicam_psf.py @@ -0,0 +1,154 @@ +import numpy as np +import os +import time +import hydra +import torch +from hydra.utils import to_absolute_path +import matplotlib.pyplot as plt +from slm_controller import slm +from lensless.utils.io import save_image, get_dtype, load_psf +from lensless.utils.plot import plot_image +from lensless.hardware.sensor import VirtualSensor +from lensless.hardware.slm import get_programmable_mask, get_intensity_psf +from waveprop.devices import slm_dict +from PIL import Image + + +@hydra.main(version_base=None, config_path="../../configs", config_name="sim_digicam_psf") +def digicam_psf(config): + + output_folder = os.getcwd() + + fp = to_absolute_path(config.digicam.pattern) + bn = os.path.basename(fp).split(".")[0] + + # digicam config + ap_center = np.array(config.digicam.ap_center) + ap_shape = np.array(config.digicam.ap_shape) + rotate_angle = config.digicam.rotate + slm_param = slm_dict[config.digicam.slm] + sensor = VirtualSensor.from_name(config.digicam.sensor) + + # simulation parameters + scene2mask = config.sim.scene2mask + mask2sensor = config.sim.mask2sensor + + torch_device = config.torch_device + dtype = get_dtype(config.dtype, config.use_torch) + + """ + Load pattern + """ + pattern = np.load(fp) + + # -- apply aperture + aperture = np.zeros(pattern.shape, dtype=np.uint8) + top_left = np.array(ap_center) - np.array(ap_shape) // 2 + bottom_right = top_left + np.array(ap_shape) + aperture[:, top_left[0] : bottom_right[0], top_left[1] : bottom_right[1]] = 1 + pattern = pattern * aperture + + # -- extract aperture region + idx_1 = ap_center[0] - ap_shape[0] // 2 + idx_2 = ap_center[1] - ap_shape[1] // 2 + + pattern_sub = pattern[ + :, + idx_1 : idx_1 + ap_shape[0], + idx_2 : idx_2 + ap_shape[1], + ] + print("Controllable region shape: ", pattern_sub.shape) + print("Total number of pixels: ", np.prod(pattern_sub.shape)) + + # -- plot full + s = slm.create(config.digicam.slm) + s.set_preview(True) + s.imshow(pattern) + plt.savefig(os.path.join(output_folder, "pattern.png")) + + # -- plot sub pattern + plt.imshow(pattern_sub.transpose(1, 2, 0)) + plt.savefig(os.path.join(output_folder, "pattern_sub.png")) + + """ + Simulate PSF + """ + start_time = time.time() + slm_vals = pattern_sub / 255.0 + + if config.digicam.slm == "adafruit": + # flatten color channel along rows + slm_vals = slm_vals.reshape((-1, slm_vals.shape[-1]), order="F") + + if config.use_torch: + slm_vals = torch.from_numpy(slm_vals).to(device=torch_device, dtype=dtype) + else: + slm_vals = slm_vals.astype(dtype) + + mask = get_programmable_mask( + vals=slm_vals, + sensor=sensor, + slm_param=slm_param, + rotate=rotate_angle, + flipud=config.sim.flipud, + ) + + # -- plot mask + if config.use_torch: + mask_np = mask.cpu().detach().numpy() + else: + mask_np = mask.copy() + mask_np = np.transpose(mask_np, (1, 2, 0)) + plt.imshow(mask_np) + plt.savefig(os.path.join(output_folder, "mask.png")) + + # -- propagate to sensor + psf_in = get_intensity_psf( + mask=mask, + sensor=sensor, + waveprop=config.sim.waveprop, + scene2mask=scene2mask, + mask2sensor=mask2sensor, + ) + + # -- plot PSF + if config.use_torch: + psf_in_np = psf_in.cpu().detach().numpy() + else: + psf_in_np = psf_in.copy() + psf_in_np = np.transpose(psf_in_np, (1, 2, 0)) + + # plot + psf_meas = None + if config.digicam.psf is not None: + fp_psf = to_absolute_path(config.digicam.psf) + if 
os.path.exists(fp_psf): + psf_meas = load_psf(fp_psf) + else: + print("Could not load PSF image from: ", fp_psf) + + fp = os.path.join(output_folder, "psf_plot.png") + if psf_meas is not None: + _, ax = plt.subplots(1, 2) + ax[0].imshow(psf_in_np) + ax[0].set_title("Simulated") + plot_image(psf_meas, gamma=config.digicam.gamma, normalize=True, ax=ax[1]) + # ax[1].imshow(psf_meas) + ax[1].set_title("Measured") + plt.savefig(fp) + else: + plt.imshow(psf_in_np) + plt.savefig(fp) + + # save PSF as png + fp = os.path.join(output_folder, f"{bn}_SIM_psf.png") + save_image(psf_in_np, fp) + + proc_time = time.time() - start_time + print(f"\nProcessing time: {proc_time:.2f} seconds") + + print(f"\nFiles saved to : {output_folder}") + + +if __name__ == "__main__": + digicam_psf()
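A possible sanity check before running the simulation above (a sketch, not part of this change set): inspect a mask pattern saved by scripts/hardware/config_digicam.py and then point the PSF simulation at it. The file name below is a placeholder for whatever timestamped .npy that script wrote, and the Hydra override simply sets digicam.pattern to that path.

import numpy as np

# placeholder name; config_digicam.py saves "{device}_{pattern}_pattern_{timestamp}.npy"
pattern_fn = "adafruit_random_pattern_YYYYMMDD_HHMMSS.npy"

pattern = np.load(pattern_fn)
print(pattern.shape)              # (3, H, W) for a color SLM, (H, W) for a monochrome one
print(pattern.dtype)              # uint8 for the random pattern (values in [0, 255])
print(np.count_nonzero(pattern))  # pixels left open after applying the aperture

# The simulated PSF could then be generated with (assuming the repository root as working directory):
#   python scripts/sim/digicam_psf.py digicam.pattern=<path to the .npy above>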