Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get correct centers of mass from rois.com even when swap_dim is True in get_contours #1370

Merged
merged 7 commits into from
Jul 9, 2024
Merged
48 changes: 19 additions & 29 deletions caiman/base/rois.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,46 +28,36 @@
pass


def com(A: np.ndarray, d1: int, d2: int, d3: Optional[int] = None) -> np.array:
def com(A, d1: int, d2: int, d3: Optional[int] = None, order: str = 'F') -> np.ndarray:
"""Calculation of the center of mass for spatial components

Args:
A: np.ndarray
matrix of spatial components (d x K)
A: np.ndarray or scipy.sparse array or matrix
matrix of spatial components (d x K).

d1: int
number of pixels in x-direction

d2: int
number of pixels in y-direction

d3: int
number of pixels in z-direction
d1, d2, d3: ints
d1, d2, and (optionally) d3 are the original dimensions of the data.

order: 'C' or 'F'
how each column of A should be reshaped to match the given dimensions.

Returns:
cm: np.ndarray
center of mass for spatial components (K x 2 or 3)
center of mass for spatial components (K x D)
"""

if 'csc_matrix' not in str(type(A)):
A = scipy.sparse.csc_matrix(A)

if d3 is None:
Coor = np.matrix([np.outer(np.ones(d2), np.arange(d1)).ravel(),
np.outer(np.arange(d2), np.ones(d1)).ravel()],
dtype=A.dtype)
else:
Coor = np.matrix([
np.outer(np.ones(d3),
np.outer(np.ones(d2), np.arange(d1)).ravel()).ravel(),
np.outer(np.ones(d3),
np.outer(np.arange(d2), np.ones(d1)).ravel()).ravel(),
np.outer(np.arange(d3),
np.outer(np.ones(d2), np.ones(d1)).ravel()).ravel()
],
dtype=A.dtype)

cm = (Coor * A / A.sum(axis=0)).T
dims = [d1, d2]
if d3 is not None:
dims.append(d3)

# make coordinate arrays where coor[d] increases from 0 to npixels[d]-1 along the dth axis
coors = np.meshgrid(*[range(d) for d in dims], indexing='ij')
coor = np.stack([c.ravel(order=order) for c in coors])

# take weighted sum of pixel positions along each coordinate
cm = (coor @ A / A.sum(axis=0)).T
return np.array(cm)


Expand Down
29 changes: 29 additions & 0 deletions caiman/tests/test_toydata.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import caiman.source_extraction.cnmf.params
from caiman.source_extraction import cnmf as cnmf
from caiman.utils.visualization import get_contours


#%%
Expand Down Expand Up @@ -64,6 +65,34 @@ def pipeline(D):
]
npt.assert_allclose(corr, 1, .05)

# Check that get_contours works regardless of swap_dim
coor_normal = get_contours(cnm.estimates.A, dims, swap_dim=False)
coor_swapped = get_contours(cnm.estimates.A, dims[::-1], swap_dim=True)
for c_normal, c_swapped in zip(coor_normal, coor_swapped):
if D == 3:
for plane_coor_normal, plane_coor_swapped in zip(c_normal['coordinates'], c_swapped['coordinates']):
compare_contour_coords(plane_coor_normal, plane_coor_swapped[:, ::-1])
else:
compare_contour_coords(c_normal['coordinates'], c_swapped['coordinates'][:, ::-1])

npt.assert_allclose(c_normal['CoM'], c_swapped['CoM'][::-1])

def compare_contour_coords(coords1: np.ndarray, coords2: np.ndarray):
"""
Compare 2 matrices of contour coordinates that should be the same, but may be calculated in a different order/
from different starting points.

The first point of each contour component is repeated, and this may be a different point depending on orientation.
To get around this, compare differences instead (have to take absolute value b/c direction may be opposite).
Also sort coordinates b/c starting point is unimportant & depends on orientation
"""
diffs_sorted = []
for coords in [coords1, coords2]:
abs_diffs = np.abs(np.diff(coords, axis=0))
sort_order = np.lexsort(abs_diffs.T)
diffs_sorted.append(abs_diffs[sort_order, :])
npt.assert_allclose(diffs_sorted[0], diffs_sorted[1])


def test_2D():
pipeline(2)
Expand Down
62 changes: 35 additions & 27 deletions caiman/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,23 +366,32 @@ def plot_unit(uid, scl):
.redim.range(unit_id=(0, nr-1), scale=(0.0, 1.0)))


def get_contours(A, dims, thr=0.9, thr_method='nrg', swap_dim=False):
def get_contours(A, dims, thr=0.9, thr_method='nrg', swap_dim=False, slice_dim: Optional[int] = None):
"""Gets contour of spatial components and returns their coordinates

Args:
A: np.ndarray or sparse matrix
Matrix of Spatial components (d x K)

dims: tuple of ints
Spatial dimensions of movie (x, y[, z])
Spatial dimensions of movie

thr: scalar between 0 and 1
Energy threshold for computing contours (default 0.9)

thr_method: [optional] string
thr_method: string
Method of thresholding:
'max' sets to zero pixels that have value less than a fraction of the max value
'nrg' keeps the pixels that contribute up to a specified fraction of the energy

swap_dim: bool
If False (default), each column of A should be reshaped in F-order to recover the mask;
this is correct if the dimensions have not been reordered from (y, x[, z]).
If True, each column should be reshaped in C-order; this is correct for dims = ([z, ]x, y).

slice_dim: int or None
Which dimension to slice along if we have 3D data. (i.e., get contours on each plane along this axis).
The default (None) is 0 if swap_dim is True, else -1.

Returns:
Coor: list of coordinates with center of mass and
Expand All @@ -392,18 +401,11 @@ def get_contours(A, dims, thr=0.9, thr_method='nrg', swap_dim=False):
if 'csc_matrix' not in str(type(A)):
A = csc_matrix(A)
d, nr = np.shape(A)
# if we are on a 3D video
if len(dims) == 3:
d1, d2, d3 = dims
x, y = np.mgrid[0:d2:1, 0:d3:1]
else:
d1, d2 = dims
x, y = np.mgrid[0:d1:1, 0:d2:1]

coordinates = []

# get the center of mass of neurons( patches )
cm = caiman.base.rois.com(A, *dims)
cm = caiman.base.rois.com(A, *dims, order='C' if swap_dim else 'F')

# for each patches
for i in range(nr):
Expand Down Expand Up @@ -437,9 +439,10 @@ def get_contours(A, dims, thr=0.9, thr_method='nrg', swap_dim=False):
Bmat = np.reshape(Bvec, dims, order='C')
else:
Bmat = np.reshape(Bvec, dims, order='F')
pars['coordinates'] = []
# for each dimensions we draw the contour
for B in (Bmat if len(dims) == 3 else [Bmat]):

def get_slice_coords(B: np.ndarray) -> np.ndarray:
"""Get contour coordinates for a 2D slice"""
d1, d2 = B.shape
vertices = find_contours(B.T, thr)
# this fix is necessary for having disjoint figures and borders plotted correctly
v = np.atleast_2d([np.nan, np.nan])
Expand All @@ -448,16 +451,26 @@ def get_contours(A, dims, thr=0.9, thr_method='nrg', swap_dim=False):
if num_close_coords < 2:
if num_close_coords == 0:
# case angle
newpt = np.round(vtx[-1, :] / [d2, d1]) * [d2, d1]
vtx = np.concatenate((vtx, newpt[np.newaxis, :]), axis=0)
newpt = np.round(np.mean(vtx[[0, -1], :], axis=0) / [d2, d1]) * [d2, d1]
vtx = np.concatenate((newpt[np.newaxis, :], vtx, newpt[np.newaxis, :]), axis=0)
else:
# case one is border
vtx = np.concatenate((vtx, vtx[0, np.newaxis]), axis=0)
v = np.concatenate(
(v, vtx, np.atleast_2d([np.nan, np.nan])), axis=0)
return v

if len(dims) == 2:
pars['coordinates'] = get_slice_coords(Bmat)
else:
# make a list of the contour coordinates for each 2D slice
pars['coordinates'] = []
if slice_dim is None:
slice_dim = 0 if swap_dim else -1
for s in range(dims[slice_dim]):
B = Bmat.take(s, axis=slice_dim)
pars['coordinates'].append(get_slice_coords(B))

pars['coordinates'] = v if len(
dims) == 2 else (pars['coordinates'] + [v])
pars['CoM'] = np.squeeze(cm[i, :])
pars['neuron_id'] = i + 1
coordinates.append(pars)
Expand Down Expand Up @@ -1098,16 +1111,11 @@ def plot_contours(A, Cn, thr=None, thr_method='max', maxthr=0.2, nrgthr=0.9, dis
plt.plot(*v.T, c=colors, **contour_args)

if display_numbers:
d1, d2 = np.shape(Cn)
d, nr = np.shape(A)
cm = caiman.base.rois.com(A, d1, d2)
nr = A.shape[1]
if max_number is None:
max_number = A.shape[1]
for i in range(np.minimum(nr, max_number)):
if swap_dim:
ax.text(cm[i, 0], cm[i, 1], str(i + 1), color=colors, **number_args)
else:
ax.text(cm[i, 1], cm[i, 0], str(i + 1), color=colors, **number_args)
max_number = nr
for i, c in zip(range(np.minimum(nr, max_number)), coordinates):
ax.text(c['CoM'][1], c['CoM'][0], str(i + 1), color=colors, **number_args)
return coordinates

def plot_shapes(Ab, dims, num_comps=15, size=(15, 15), comps_per_row=None,
Expand Down