Skip to content

Commit

Permalink
Merge pull request py4dstem#551 from bsavitzky/morestrain
Browse files Browse the repository at this point in the history
I've got a blank stress tensor, baby, and I'll write your strain
  • Loading branch information
bsavitzky authored Nov 6, 2023
2 parents ff49b03 + e23878c commit 2bd2fa6
Show file tree
Hide file tree
Showing 7 changed files with 1,147 additions and 315 deletions.
160 changes: 101 additions & 59 deletions py4DSTEM/braggvectors/braggvector_methods.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# BraggVectors methods

import numpy as np
from scipy.ndimage import gaussian_filter
from warnings import warn
import inspect
from warnings import warn

from emdfile import Array, Metadata, tqdmnd, _read_metadata
import matplotlib.pyplot as plt
import numpy as np
from emdfile import Array, Metadata, _read_metadata, tqdmnd
from py4DSTEM import show
from py4DSTEM.datacube import VirtualImage
from scipy.ndimage import gaussian_filter


class BraggVectorMethods:
Expand Down Expand Up @@ -518,6 +520,7 @@ def fit_origin(
mask_check_data=True,
plot=True,
plot_range=None,
cmap="RdBu_r",
returncalc=True,
**kwargs,
):
Expand All @@ -537,6 +540,7 @@ def fit_origin(
mask_check_data (bool): Get mask from origin measurements equal to zero. (TODO - replace)
plot (bool, optional): plot results
plot_range (float): min and max color range for plot (pixels)
cmap (colormap): plotting colormap
Returns:
(variable): Return value depends on returnfitp. If ``returnfitp==False``
Expand Down Expand Up @@ -567,75 +571,98 @@ def fit_origin(
robust_thresh=robust_thresh,
)

# try to add to calibration
# try to add update calibration metadata
try:
self.calibration.set_origin([qx0_fit, qy0_fit])
self.calibration.set_origin((qx0_fit, qy0_fit))
self.setcal()
except AttributeError:
warn(
"No calibration found on this datacube - fit values are not being stored"
)
pass
if plot:
from py4DSTEM.visualize import show_image_grid

if mask is None:
qx0_meas, qy0_meas = q_meas
qx0_res_plot = qx0_residuals
qy0_res_plot = qy0_residuals
else:
qx0_meas = np.ma.masked_array(q_meas[0], mask=np.logical_not(mask))
qy0_meas = np.ma.masked_array(q_meas[1], mask=np.logical_not(mask))
qx0_res_plot = np.ma.masked_array(
qx0_residuals, mask=np.logical_not(mask)
)
qy0_res_plot = np.ma.masked_array(
qy0_residuals, mask=np.logical_not(mask)
)
qx0_mean = np.mean(qx0_fit)
qy0_mean = np.mean(qy0_fit)

if plot_range is None:
plot_range = 2 * np.max(qx0_fit - qx0_mean)

cmap = kwargs.get("cmap", "RdBu_r")
kwargs.pop("cmap", None)
axsize = kwargs.get("axsize", (6, 2))
kwargs.pop("axsize", None)

show_image_grid(
lambda i: [
qx0_meas - qx0_mean,
qx0_fit - qx0_mean,
qx0_res_plot,
qy0_meas - qy0_mean,
qy0_fit - qy0_mean,
qy0_res_plot,
][i],
H=2,
W=3,
# show
if plot:
self.show_origin_fit(
q_meas[0],
q_meas[1],
qx0_fit,
qy0_fit,
qx0_residuals,
qy0_residuals,
mask=mask,
plot_range=plot_range,
cmap=cmap,
axsize=axsize,
title=[
"measured origin, x",
"fitorigin, x",
"residuals, x",
"measured origin, y",
"fitorigin, y",
"residuals, y",
],
vmin=-1 * plot_range,
vmax=1 * plot_range,
intensity_range="absolute",
**kwargs,
)

# update calibration metadata
self.calibration.set_origin((qx0_fit, qy0_fit))
self.setcal()

# return
if returncalc:
return qx0_fit, qy0_fit, qx0_residuals, qy0_residuals

def show_origin_fit(
self,
qx0_meas,
qy0_meas,
qx0_fit,
qy0_fit,
qx0_residuals,
qy0_residuals,
mask=None,
plot_range=None,
cmap="RdBu_r",
**kwargs,
):
# apply mask
if mask is not None:
qx0_meas = np.ma.masked_array(qx0_meas, mask=np.logical_not(mask))
qy0_meas = np.ma.masked_array(qy0_meas, mask=np.logical_not(mask))
qx0_residuals = np.ma.masked_array(qx0_residuals, mask=np.logical_not(mask))
qy0_residuals = np.ma.masked_array(qy0_residuals, mask=np.logical_not(mask))
qx0_mean = np.mean(qx0_fit)
qy0_mean = np.mean(qy0_fit)

# set range
if plot_range is None:
plot_range = max(
(
1.5 * np.max(np.abs(qx0_fit - qx0_mean)),
1.5 * np.max(np.abs(qy0_fit - qy0_mean)),
)
)

# set figsize
imsize_ratio = np.sqrt(qx0_meas.shape[1] / qx0_meas.shape[0])
axsize = (3 * imsize_ratio, 3 / imsize_ratio)
axsize = kwargs.pop("axsize", axsize)

# plot
fig, ax = show(
[
[qx0_meas - qx0_mean, qx0_fit - qx0_mean, qx0_residuals],
[qy0_meas - qy0_mean, qy0_fit - qy0_mean, qy0_residuals],
],
cmap=cmap,
axsize=axsize,
title=[
"measured origin, x",
"fitorigin, x",
"residuals, x",
"measured origin, y",
"fitorigin, y",
"residuals, y",
],
vmin=-1 * plot_range,
vmax=1 * plot_range,
intensity_range="absolute",
show_cbar=True,
returnfig=True,
**kwargs,
)
plt.tight_layout()

return

def fit_p_ellipse(
self, bvm, center, fitradii, mask=None, returncalc=False, **kwargs
):
Expand Down Expand Up @@ -771,6 +798,21 @@ def mask_in_R(self, mask, update_inplace=False, returncalc=True):
else:
return

def to_strainmap(self, name: str = None):
"""
Generate a StrainMap object from the BraggVectors
equivalent to py4DSTEM.StrainMap(braggvectors=braggvectors)
Args:
name (str, optional): The name of the strainmap. Defaults to None which reverts to default name 'strainmap'.
Returns:
py4DSTEM.StrainMap: A py4DSTEM StrainMap object generated from the BraggVectors
"""
from py4DSTEM.process.strain import StrainMap

return StrainMap(self, name) if name else StrainMap(self)


######### END BraggVectorMethods CLASS ########

Expand Down
12 changes: 6 additions & 6 deletions py4DSTEM/braggvectors/braggvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def setcal(
if pixel is None:
pixel = False if c.get_Q_pixel_size() == 1 else True
if rotate is None:
rotate = False if c.get_QR_rotflip() is None else True
rotate = False if c.get_QR_rotation() is None else True

# validate requested state
if center:
Expand All @@ -210,7 +210,7 @@ def setcal(
if pixel:
assert c.get_Q_pixel_size() is not None, "Requested calibration not found"
if rotate:
assert c.get_QR_rotflip() is not None, "Requested calibration not found"
assert c.get_QR_rotation() is not None, "Requested calibration not found"

# set the calibrations
self._calstate = {
Expand Down Expand Up @@ -478,15 +478,15 @@ def _transform(

# Q/R rotation
if rotate:
flip = cal.get_QR_flip()
theta = cal.get_QR_rotation_degrees()
assert flip is not None, "Requested calibration was not found!"
theta = cal.get_QR_rotation()
assert theta is not None, "Requested calibration was not found!"
flip = cal.get_QR_flip()
flip = False if flip is None else flip
# rotation matrix
R = np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
# apply
# rotate and flip
if flip:
positions = R @ np.vstack((ans["qy"], ans["qx"]))
else:
Expand Down
31 changes: 31 additions & 0 deletions py4DSTEM/data/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def __init__(
self["R_pixel_size"] = 1
self["Q_pixel_units"] = "pixels"
self["R_pixel_units"] = "pixels"
self["QR_flip"] = False

# EMD root property
@property
Expand Down Expand Up @@ -666,8 +667,17 @@ def ellipse(self, x):

# Q/R-space rotation and flip

@call_calibrate
def set_QR_rotation(self, x):
self._params["QR_rotation"] = x
self._params["QR_rotation_degrees"] = np.degrees(x)

def get_QR_rotation(self):
return self._get_value("QR_rotation")

@call_calibrate
def set_QR_rotation_degrees(self, x):
self._params["QR_rotation"] = np.radians(x)
self._params["QR_rotation_degrees"] = x

def get_QR_rotation_degrees(self):
Expand All @@ -689,10 +699,31 @@ def set_QR_rotflip(self, rot_flip):
flip (bool): True indicates a Q/R axes flip
"""
rot, flip = rot_flip
self._params["QR_rotation"] = rot
self._params["QR_rotation_degrees"] = np.degrees(rot)
self._params["QR_flip"] = flip

@call_calibrate
def set_QR_rotflip_degrees(self, rot_flip):
"""
Args:
rot_flip (tuple), (rot, flip) where:
rot (number): rotation in degrees
flip (bool): True indicates a Q/R axes flip
"""
rot, flip = rot_flip
self._params["QR_rotation"] = np.radians(rot)
self._params["QR_rotation_degrees"] = rot
self._params["QR_flip"] = flip

def get_QR_rotflip(self):
rot = self.get_QR_rotation()
flip = self.get_QR_flip()
if rot is None or flip is None:
return None
return (rot, flip)

def get_QR_rotflip_degrees(self):
rot = self.get_QR_rotation_degrees()
flip = self.get_QR_flip()
if rot is None or flip is None:
Expand Down
Loading

0 comments on commit 2bd2fa6

Please sign in to comment.