Skip to content

Commit

Permalink
Merge pull request py4dstem#550 from py4dstem/phase_contrast
Browse files Browse the repository at this point in the history
Bad blood with phase retrieval? We're not out the woods yet, but shake it off
  • Loading branch information
smribet authored Nov 6, 2023
2 parents 02c419a + dd09924 commit ff49b03
Show file tree
Hide file tree
Showing 14 changed files with 1,139 additions and 241 deletions.
287 changes: 271 additions & 16 deletions py4DSTEM/process/phase/iterative_base_class.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion py4DSTEM/process/phase/iterative_dpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,7 @@ def reconstruct(
anti_gridding=anti_gridding,
)

self.error_iterations.append(self.error.item())
if store_iterations:
self.object_phase_iterations.append(
asnumpy(
Expand All @@ -807,7 +808,6 @@ def reconstruct(
].copy()
)
)
self.error_iterations.append(self.error.item())

if self._step_size < stopping_criterion:
if self._verbose:
Expand Down
151 changes: 147 additions & 4 deletions py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,17 @@ class MixedstateMultislicePtychographicReconstruction(PtychographicReconstructio
initial_scan_positions: np.ndarray, optional
Probe positions in Å for each diffraction intensity
If None, initialized to a grid scan
theta_x: float
x tilt of propagator (in degrees)
theta_y: float
y tilt of propagator (in degrees)
middle_focus: bool
if True, adds half the sample thickness to the defocus
object_type: str, optional
The object can be reconstructed as a real potential ('potential') or a complex
object ('complex')
positions_mask: np.ndarray, optional
Boolean real space mask to select positions in datacube to skip for reconstruction
verbose: bool, optional
If True, class methods will inherit this and print additional information
device: str, optional
Expand Down Expand Up @@ -114,7 +122,11 @@ def __init__(
initial_object_guess: np.ndarray = None,
initial_probe_guess: np.ndarray = None,
initial_scan_positions: np.ndarray = None,
theta_x: float = 0,
theta_y: float = 0,
middle_focus: bool = False,
object_type: str = "complex",
positions_mask: np.ndarray = None,
verbose: bool = True,
device: str = "cpu",
name: str = "multi-slice_ptychographic_reconstruction",
Expand Down Expand Up @@ -162,6 +174,25 @@ def __init__(
if (key not in polar_symbols) and (key not in polar_aliases.keys()):
raise ValueError("{} not a recognized parameter".format(key))

if np.isscalar(slice_thicknesses):
mean_slice_thickness = slice_thicknesses
else:
mean_slice_thickness = np.mean(slice_thicknesses)

if middle_focus:
if "defocus" in kwargs:
kwargs["defocus"] += mean_slice_thickness * num_slices / 2
elif "C10" in kwargs:
kwargs["C10"] -= mean_slice_thickness * num_slices / 2
elif polar_parameters is not None and "defocus" in polar_parameters:
polar_parameters["defocus"] = (
polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2
)
elif polar_parameters is not None and "C10" in polar_parameters:
polar_parameters["C10"] = (
polar_parameters["C10"] - mean_slice_thickness * num_slices / 2
)

self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols)))

if polar_parameters is None:
Expand All @@ -186,6 +217,13 @@ def __init__(
f"object_type must be either 'potential' or 'complex', not {object_type}"
)

if positions_mask is not None and positions_mask.dtype != "bool":
warnings.warn(
("`positions_mask` converted to `bool` array"),
UserWarning,
)
positions_mask = np.asarray(positions_mask, dtype="bool")

self.set_save_defaults()

# Data
Expand All @@ -201,6 +239,7 @@ def __init__(
self._semiangle_cutoff_pixels = semiangle_cutoff_pixels
self._rolloff = rolloff
self._object_type = object_type
self._positions_mask = positions_mask
self._object_padding_px = object_padding_px
self._verbose = verbose
self._device = device
Expand All @@ -210,13 +249,17 @@ def __init__(
self._num_probes = num_probes
self._num_slices = num_slices
self._slice_thicknesses = slice_thicknesses
self._theta_x = theta_x
self._theta_y = theta_y

def _precompute_propagator_arrays(
self,
gpts: Tuple[int, int],
sampling: Tuple[float, float],
energy: float,
slice_thicknesses: Sequence[float],
theta_x: float,
theta_y: float,
):
"""
Precomputes propagator arrays complex wave-function will be convolved by,
Expand All @@ -232,6 +275,10 @@ def _precompute_propagator_arrays(
The electron energy of the wave functions in eV
slice_thicknesses: Sequence[float]
Array of slice thicknesses in A
theta_x: float
x tilt of propagator (in degrees)
theta_y: float
y tilt of propagator (in degrees)
Returns
-------
Expand All @@ -251,13 +298,23 @@ def _precompute_propagator_arrays(
propagators = xp.empty(
(num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64
)

theta_x = np.deg2rad(theta_x)
theta_y = np.deg2rad(theta_y)

for i, dz in enumerate(slice_thicknesses):
propagators[i] = xp.exp(
1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz)
)
propagators[i] *= xp.exp(
1.0j * (-(ky**2)[None] * np.pi * wavelength * dz)
)
propagators[i] *= xp.exp(
1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x))
)
propagators[i] *= xp.exp(
1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y))
)

return propagators

Expand Down Expand Up @@ -445,7 +502,11 @@ def preprocess(
self._amplitudes,
self._mean_diffraction_intensity,
) = self._normalize_diffraction_intensities(
self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns
self._intensities,
self._com_fitted_x,
self._com_fitted_y,
crop_patterns,
self._positions_mask,
)

# explicitly delete namespace
Expand All @@ -454,7 +515,7 @@ def preprocess(
del self._intensities

self._positions_px = self._calculate_scan_positions_in_pixels(
self._scan_positions
self._scan_positions, self._positions_mask
)

# handle semiangle specified in pixels
Expand Down Expand Up @@ -597,6 +658,8 @@ def preprocess(
self.sampling,
self._energy,
self._slice_thicknesses,
self._theta_x,
self._theta_y,
)

# overlaps
Expand Down Expand Up @@ -3060,6 +3123,7 @@ def show_slices(
common_color_scale: bool = True,
padding: int = 0,
num_cols: int = 3,
show_fft: bool = False,
**kwargs,
):
"""
Expand All @@ -3075,12 +3139,20 @@ def show_slices(
Padding to leave uncropped
num_cols: int, optional
Number of GridSpec columns
show_fft: bool, optional
if True, plots fft of object slices
"""

if ms_object is None:
ms_object = self._object

rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding)
if show_fft:
rotated_object = np.abs(
np.fft.fftshift(
np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1)
)
)
rotated_shape = rotated_object.shape

if np.iscomplexobj(rotated_object):
Expand All @@ -3098,8 +3170,21 @@ def show_slices(

axsize = kwargs.pop("axsize", (3, 3))
cmap = kwargs.pop("cmap", "magma")
vmin = np.min(rotated_object) if common_color_scale else None
vmax = np.max(rotated_object) if common_color_scale else None

if common_color_scale:
vals = np.sort(rotated_object.ravel())
ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int")
ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int")
ind_vmin = np.max([0, ind_vmin])
ind_vmax = np.min([len(vals) - 1, ind_vmax])
vmin = vals[ind_vmin]
vmax = vals[ind_vmax]
if vmax == vmin:
vmin = vals[0]
vmax = vals[-1]
else:
vmax = None
vmin = None
vmin = kwargs.pop("vmin", vmin)
vmax = kwargs.pop("vmax", vmax)

Expand Down Expand Up @@ -3509,3 +3594,61 @@ def _return_object_fft(

obj = self._crop_rotate_object_fov(np.sum(obj, axis=0))
return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj))))

def _return_self_consistency_errors(
self,
max_batch_size=None,
):
"""Compute the self-consistency errors for each probe position"""

xp = self._xp
asnumpy = self._asnumpy

# Batch-size
if max_batch_size is None:
max_batch_size = self._num_diffraction_patterns

# Re-initialize fractional positions and vector patches
errors = np.array([])
positions_px = self._positions_px.copy()

for start, end in generate_batches(
self._num_diffraction_patterns, max_batch=max_batch_size
):
# batch indices
self._positions_px = positions_px[start:end]
self._positions_px_fractional = self._positions_px - xp.round(
self._positions_px
)
(
self._vectorized_patch_indices_row,
self._vectorized_patch_indices_col,
) = self._extract_vectorized_patch_indices()
amplitudes = self._amplitudes[start:end]

# Overlaps
_, _, overlap = self._overlap_projection(self._object, self._probe)
fourier_overlap = xp.fft.fft2(overlap)
intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1))

# Normalized mean-squared errors
batch_errors = xp.sum(
xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1)
)
errors = np.hstack((errors, batch_errors))

self._positions_px = positions_px.copy()
errors /= self._mean_diffraction_intensity

return asnumpy(errors)

def _return_projected_cropped_potential(
self,
):
"""Utility function to accommodate multiple classes"""
if self._object_type == "complex":
projected_cropped_potential = np.angle(self.object_cropped).sum(0)
else:
projected_cropped_potential = self.object_cropped.sum(0)

return projected_cropped_potential
Loading

0 comments on commit ff49b03

Please sign in to comment.