Skip to content

Commit

Permalink
Merge branch 'main' into use_dict3
Browse files Browse the repository at this point in the history
  • Loading branch information
gduscher authored Jun 3, 2024
2 parents 3e29b7c + ad8a209 commit db41ac6
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 9 deletions.
80 changes: 72 additions & 8 deletions pyTEMlib/image_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@

from collections import Counter

# center diff function
from skimage.filters import threshold_otsu, sobel
from scipy.optimize import leastsq
from sklearn.cluster import DBSCAN


_SimpleITK_present = True
try:
Expand Down Expand Up @@ -275,6 +280,67 @@ def diffractogram_spots(dset, spot_threshold, return_center=True, eps=0.1):
return spots, center


def center_diffractogram(dset, return_plot = True, histogram_factor = None, smoothing = 1, min_samples = 100):
try:
diff = np.array(dset).T.astype(np.float16)
diff[diff < 0] = 0

if histogram_factor is not None:
hist, bins = np.histogram(np.ravel(diff), bins=256, range=(0, 1), density=True)
threshold = threshold_otsu(diff, hist = hist * histogram_factor)
else:
threshold = threshold_otsu(diff)
binary = (diff > threshold).astype(float)
smoothed_image = ndimage.gaussian_filter(binary, sigma=smoothing) # Smooth before edge detection
smooth_threshold = threshold_otsu(smoothed_image)
smooth_binary = (smoothed_image > smooth_threshold).astype(float)
# Find the edges using the Sobel operator
edges = sobel(smooth_binary)
edge_points = np.argwhere(edges)

# Use DBSCAN to cluster the edge points
db = DBSCAN(eps=10, min_samples=min_samples).fit(edge_points)
labels = db.labels_
if len(set(labels)) == 1:
raise ValueError("DBSCAN clustering resulted in only one group, check the parameters.")

# Get the largest group of edge points
unique, counts = np.unique(labels, return_counts=True)
counts = dict(zip(unique, counts))
largest_group = max(counts, key=counts.get)
edge_points = edge_points[labels == largest_group]

# Fit a circle to the diffraction ring
def calc_distance(c, x, y):
Ri = np.sqrt((x - c[0])**2 + (y - c[1])**2)
return Ri - Ri.mean()
x_m = np.mean(edge_points[:, 1])
y_m = np.mean(edge_points[:, 0])
center_guess = x_m, y_m
center, ier = leastsq(calc_distance, center_guess, args=(edge_points[:, 1], edge_points[:, 0]))
mean_radius = np.mean(calc_distance(center, edge_points[:, 1], edge_points[:, 0])) + np.sqrt((edge_points[:, 1] - center[0])**2 + (edge_points[:, 0] - center[1])**2).mean()

finally:
if return_plot:
fig, ax = plt.subplots(1, 4, figsize=(10, 4))
ax[0].set_title('Diffractogram')
ax[0].imshow(dset.T, cmap='viridis')
ax[1].set_title('Otsu Binary Image')
ax[1].imshow(binary, cmap='gray')
ax[2].set_title('Smoothed Binary Image')
ax[2].imshow(smooth_binary, cmap='gray')
ax[3].set_title('Edge Detection and Fitting')
ax[3].imshow(edges, cmap='gray')
ax[3].scatter(center[0], center[1], c='r', s=10)
circle = plt.Circle(center, mean_radius, color='red', fill=False)
ax[3].add_artist(circle)
for axis in ax:
axis.axis('off')
fig.tight_layout()

return center


def adaptive_fourier_filter(dset, spots, low_pass=3, reflection_radius=0.3):
"""
Use spots in diffractogram for a Fourier Filter
Expand Down Expand Up @@ -1069,25 +1135,23 @@ def cartesian2polar(x, y, grid, r, t, order=3):
return ndimage.map_coordinates(grid, np.array([new_ix, new_iy]), order=order).reshape(new_x.shape)


def warp(diff):
"""Takes a centered diffraction pattern (as a sidpy dataset)and warps it to a polar grid"""
"""Centered diff can be produced with it.diffractogram_spots(return_center = True)"""
def warp(diff, center):
"""Takes a diffraction pattern (as a sidpy dataset)and warps it to a polar grid"""

# Define original polar grid
nx = np.shape(diff)[0]
ny = np.shape(diff)[1]

# Define center pixel
pix2nm = np.gradient(diff.u.values)[0]
center_pixel = [abs(min(diff.u.values)), abs(min(diff.v.values))]//pix2nm

x = np.linspace(1, nx, nx, endpoint=True)-center_pixel[0]
y = np.linspace(1, ny, ny, endpoint=True)-center_pixel[1]
x = np.linspace(1, nx, nx, endpoint=True)-center[0]
y = np.linspace(1, ny, ny, endpoint=True)-center[1]
z = diff

# Define new polar grid
nr = int(min([center_pixel[0], center_pixel[1], diff.shape[0]-center_pixel[0], diff.shape[1]-center_pixel[1]])-1)
nt = 360*3
nr = int(min([center[0], center[1], diff.shape[0]-center[0], diff.shape[1]-center[1]])-1)
nt = 360 * 3

r = np.linspace(1, nr, nr)
t = np.linspace(0., np.pi, nt, endpoint=False)
Expand Down
2 changes: 1 addition & 1 deletion pyTEMlib/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
"""
_version = '0.2024.05.0'
__version__ = _version
_time = '2024-04-11 19:58:26'
_time = '2024-05-30 19:58:26'

0 comments on commit db41ac6

Please sign in to comment.