Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added option to use the inner product as a dissimilarity measure. #395

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
11a0f8c
Added option to use the inner product as a dissimilarity measure.
Jun 9, 2024
9e8f18e
Added option to use the inner product as a dissimilarity measure.
Jun 13, 2024
58fd5a9
Added option to use the inner product as a dissimilarity measure.
Jun 13, 2024
b458b05
Added option to use the inner product as a dissimilarity measure.
Jun 13, 2024
6f2c473
Added option to use the inner product as a dissimilarity measure.
Jun 13, 2024
b282f96
Added option to use the inner product as a dissimilarity measure.
Jun 13, 2024
0910d89
Experimenting with using all-ones and all-zeros with the whitened cos…
Jul 16, 2024
d524aab
Now removing the all-zero rows from the covariance matrix (correspond…
Jul 17, 2024
8e613e7
Now removing the all-zero rows from the covariance matrix (correspond…
Jul 17, 2024
0af8260
Commit before changing branch
Jul 30, 2024
0445051
Commit before changing branch
Jul 30, 2024
2eedeeb
Trying just the diagonals
Aug 1, 2024
aba8898
Now trying the framed RSA again, using diagonal covariance matrix.
Aug 3, 2024
5736b8a
Now trying the framed RSA again, using diagonal covariance matrix.
Aug 3, 2024
b2c341a
Now trying the framed RSA again, using diagonal covariance matrix.
Aug 3, 2024
00c5539
Now trying the framed RSA again, using diagonal covariance matrix.
Aug 3, 2024
34a32a0
Now trying the framed RSA again, using diagonal covariance matrix.
Aug 3, 2024
d4f0d7f
Now trying the framed RSA again, using diagonal covariance matrix.
Aug 11, 2024
917c9a9
Trying another way of correcting the covariance.
Aug 12, 2024
2263314
Trying another way of correcting the covariance.
Aug 12, 2024
93232b0
Trying another way of correcting the covariance.
Aug 15, 2024
ff56b3c
Trying all the framed RSA options with no covariance adjustment: does…
Aug 21, 2024
6c1f1eb
Comparing constant c values across a wide range
Aug 31, 2024
b2bd269
Testing whether fixing the dof for the noise covariance fixes our pro…
Sep 4, 2024
7e42934
Testing whether fixing the dof for the noise covariance fixes our pro…
Sep 4, 2024
fb8acbe
Running more of the tuning c-norm tests for framed RSA.
Sep 5, 2024
e56add1
Running more of the tuning c-norm tests for framed RSA.
Sep 5, 2024
95c218a
Computing covariance using all available data for each subject.
Sep 15, 2024
2b80b50
Fixed the calculation of the RDM covariance matrix for the frozen pat…
Dec 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 33 additions & 17 deletions src/rsatoolbox/data/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,18 @@ def _check_demean(matrix):
matrix = matrix - np.mean(matrix, axis=0, keepdims=True)
dof = matrix.shape[0] - 1
elif matrix.ndim == 3:
matrix -= np.mean(matrix, axis=2, keepdims=True)
dof = (matrix.shape[0] - 1) * matrix.shape[2]
matrix -= np.nanmean(matrix, axis=2, keepdims=True)
# dof = (matrix.shape[0] - 1) * matrix.shape[2] JohnMark checking
dof = matrix.shape[0] * (matrix.shape[2] - 1)
matrix = matrix.transpose(0, 2, 1).reshape(
matrix.shape[0] * matrix.shape[2], matrix.shape[1])
matrix = matrix[~np.isnan(matrix).any(axis=1)]
else:
raise ValueError('Matrix for covariance estimation has wrong # of dimensions!')
return matrix, dof


def _estimate_covariance(matrix, dof, method):
def _estimate_covariance(matrix, dof, method, lamb_opt=None):
""" calls the right covariance estimation function based on the ""method" argument

Args:
Expand All @@ -66,7 +68,7 @@ def _estimate_covariance(matrix, dof, method):
if method == 'shrinkage_eye':
cov_mat = _covariance_eye(matrix, dof)
elif method == 'shrinkage_diag':
cov_mat = _covariance_diag(matrix, dof)
cov_mat = _covariance_diag(matrix, dof, lamb_opt)
elif method == 'diag':
cov_mat = _variance(matrix, dof)
elif method == 'full':
Expand Down Expand Up @@ -147,13 +149,13 @@ def _covariance_eye(matrix, dof):
b2 = min(d2, b2)
# shrink covariance matrix
s_shrink = b2 / d2 * m * np.eye(s.shape[0]) \
+ (d2-b2) / d2 * s
+ (d2 - b2) / d2 * s
# correction for degrees of freedom
s_shrink = s_shrink * matrix.shape[0] / dof
return s_shrink


def _covariance_diag(matrix, dof, mem_threshold=(10**9)/8):
def _covariance_diag(matrix, dof, lamb_opt=None, mem_threshold=(10 ** 9) / 8):
"""
computes the sample covariance matrix from a 2d-array.
matrix should be demeaned before!
Expand Down Expand Up @@ -186,15 +188,29 @@ def _covariance_diag(matrix, dof, mem_threshold=(10**9)/8):
s = s_sum / dof
var = np.diag(s)
std = np.sqrt(var)
s_mean = s_sum / np.expand_dims(std, 0) / np.expand_dims(std, 1) / (matrix.shape[0] - 1)
s2_mean = s2_sum / np.expand_dims(var, 0) / np.expand_dims(var, 1) / (matrix.shape[0] - 1)
var_hat = matrix.shape[0] / dof ** 2 \
* (s2_mean - s_mean ** 2)
mask = ~np.eye(s.shape[0], dtype=bool)
lamb = np.sum(var_hat[mask]) / np.sum(s_mean[mask] ** 2)
lamb = max(min(lamb, 1), 0)
scaling = np.eye(s.shape[0]) + (1-lamb) * mask
if lamb_opt is None:
# s_mean = s_sum / np.expand_dims(std, 0) / np.expand_dims(std, 1) / (matrix.shape[0] - 1) JohnMark check
# s2_mean = s2_sum / np.expand_dims(var, 0) / np.expand_dims(var, 1) / (matrix.shape[0] - 1)
s_mean = s_sum / np.expand_dims(std, 0) / np.expand_dims(std, 1) / dof
s2_mean = s2_sum / np.expand_dims(var, 0) / np.expand_dims(var, 1) / dof
var_hat = matrix.shape[0] / dof ** 2 \
* (s2_mean - s_mean ** 2)
lamb = np.sum(var_hat[mask]) / np.sum(s_mean[mask] ** 2)
lamb = max(min(lamb, 1), 0)
else:
lamb = lamb_opt
scaling = np.eye(s.shape[0]) + (1 - lamb) * mask
s_shrink = s * scaling

mean_shrunk_eigenvalue = np.mean(np.linalg.eigvals(s_shrink))
mean_full_eigenvalue = np.mean(np.linalg.eigvals(s))
print(f"data shape: {matrix.shape}")
print(f"dof: {dof}")
print(f"lambda: {lamb}")
print(f"mean var: {np.mean(var)}")
print(f"mean full eigenvalue: {mean_full_eigenvalue}")
print(f"mean shrunk eigenvalue: {mean_shrunk_eigenvalue}")
return s_shrink


Expand Down Expand Up @@ -272,7 +288,7 @@ def prec_from_residuals(residuals, dof=None, method='shrinkage_diag'):
return prec


def cov_from_measurements(dataset, obs_desc, dof=None, method='shrinkage_diag'):
def cov_from_measurements(dataset, obs_desc, dof=None, method='shrinkage_diag', lamb_opt=None):
"""
Estimates a covariance matrix from measurements. Allows for shrinkage estimates.
Use 'method' to choose which estimation method is used.
Expand Down Expand Up @@ -311,11 +327,11 @@ def cov_from_measurements(dataset, obs_desc, dof=None, method='shrinkage_diag'):
"obs_desc not contained in the dataset's obs_descriptors"
tensor, _ = dataset.get_measurements_tensor(obs_desc)
# calculate sample covariance matrix s
cov_mat = _estimate_covariance(tensor, dof, method)
cov_mat = _estimate_covariance(tensor, dof, method, lamb_opt)
return cov_mat


def prec_from_measurements(dataset, obs_desc, dof=None, method='shrinkage_diag'):
def prec_from_measurements(dataset, obs_desc, dof=None, method='shrinkage_diag', lamb_opt=None):
"""
Estimates the covariance matrix from measurements and finds its multiplicative
inverse (= the precision matrix)
Expand All @@ -337,7 +353,7 @@ def prec_from_measurements(dataset, obs_desc, dof=None, method='shrinkage_diag')
numpy.ndarray (or list): sigma_p: precision matrix over channels

"""
cov = cov_from_measurements(dataset, obs_desc, dof=dof, method=method)
cov = cov_from_measurements(dataset, obs_desc, dof=dof, method=method, lamb_opt=lamb_opt)
if not isinstance(cov, np.ndarray):
prec = [None] * len(cov)
for i, cov_i in enumerate(cov):
Expand Down
109 changes: 92 additions & 17 deletions src/rsatoolbox/rdm/calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ def calc_rdm(
elif method == 'crossnobis':
rdm = calc_rdm_crossnobis(dataset, descriptor, noise,
cv_descriptor, remove_mean)
elif method == 'dotproduct':
rdm = calc_rdm_dotproduct(dataset, descriptor)
elif method == 'mean_profile':
rdm = calc_rdm_mean_profile(dataset, descriptor)
elif method == 'norm_profile':
rdm = calc_rdm_norm_profile(dataset, descriptor)
elif method == 'poisson':
rdm = calc_rdm_poisson(dataset, descriptor,
prior_lambda=prior_lambda,
Expand All @@ -105,8 +111,12 @@ def calc_rdm(
prior_weight=prior_weight)
else:
raise NotImplementedError
if descriptor is not None:
if (descriptor is not None) and (method not in ['mean_profile', 'norm_profile']):
rdm.sort_by(**{descriptor: 'alpha'})
else:
desc = np.unique(np.array(dataset.obs_descriptors[descriptor]))
inds = desc.argsort()
rdm = rdm[inds]
return rdm


Expand Down Expand Up @@ -209,9 +219,9 @@ def calc_rdm_euclidean(
rsatoolbox.rdm.rdms.RDMs: RDMs object with the one RDM
"""
measurements, desc = _parse_input(dataset, descriptor, remove_mean)
sum_sq_measurements = np.sum(measurements**2, axis=1, keepdims=True)
sum_sq_measurements = np.sum(measurements ** 2, axis=1, keepdims=True)
rdm = sum_sq_measurements + sum_sq_measurements.T \
- 2 * np.dot(measurements, measurements.T)
- 2 * np.dot(measurements, measurements.T)
rdm = _extract_triu_(rdm) / measurements.shape[1]
return _build_rdms(rdm, dataset, 'squared euclidean', descriptor, desc)

Expand Down Expand Up @@ -269,7 +279,7 @@ def calc_rdm_mahalanobis(dataset, descriptor=None, noise=None, remove_mean: bool
noise = _check_noise(noise, dataset.n_channel)
kernel = measurements @ noise @ measurements.T
rdm = np.expand_dims(np.diag(kernel), 0) + \
np.expand_dims(np.diag(kernel), 1) - 2 * kernel
np.expand_dims(np.diag(kernel), 1) - 2 * kernel
rdm = _extract_triu_(rdm) / measurements.shape[1]
return _build_rdms(
rdm,
Expand All @@ -281,6 +291,71 @@ def calc_rdm_mahalanobis(dataset, descriptor=None, noise=None, remove_mean: bool
)


def calc_rdm_dotproduct(
        dataset: DatasetBase,
        descriptor: Optional[str] = None,
        remove_mean: bool = False):
    """
    Builds an RDM whose entries are the pairwise dot products (inner
    products) between the patterns of a dataset.

    Args:
        dataset (rsatoolbox.data.DatasetBase):
            The dataset the RDM is computed from
        descriptor (String):
            obs_descriptor used to define the rows/columns of the RDM
            defaults to one row/column per row in the dataset
        remove_mean (bool):
            whether the mean of each pattern shall be removed
            before calculating dotproducts.
    Returns:
        rsatoolbox.rdm.rdms.RDMs: RDMs object with the one RDM
    """
    patterns, pattern_desc = _parse_input(dataset, descriptor, remove_mean)
    # Gram matrix: entry (i, j) is the inner product of pattern i and
    # pattern j. No division by the number of channels is applied here.
    gram = patterns @ patterns.T
    return _build_rdms(
        _extract_triu_(gram), dataset, 'dotproduct', descriptor, pattern_desc)


def calc_rdm_mean_profile(
        dataset: DatasetBase,
        descriptor: Optional[str] = None) -> np.ndarray:
    """
    Computes the mean of each pattern's measurements.

    Despite the ``calc_rdm_`` prefix, the result is a plain array with one
    value per pattern, not an RDMs object.

    Args:
        dataset (rsatoolbox.data.DatasetBase):
            The dataset the profile is computed from
        descriptor (String):
            obs_descriptor used to define the patterns
            defaults to one pattern per row in the dataset
    Returns:
        numpy.ndarray: mean over axis 1 of the parsed measurements
        (presumably the channel axis -- TODO confirm), one entry per pattern
    """
    # The patterns are always used as-is (remove_mean=False); the
    # accompanying descriptor values are not needed for the profile.
    measurements, _ = _parse_input(dataset, descriptor, remove_mean=False)
    return measurements.mean(axis=1)


def calc_rdm_norm_profile(
        dataset: DatasetBase,
        descriptor: Optional[str] = None) -> np.ndarray:
    """
    Computes the vector norm of each pattern's measurements.

    Despite the ``calc_rdm_`` prefix, the result is a plain array with one
    value per pattern, not an RDMs object.

    Args:
        dataset (rsatoolbox.data.DatasetBase):
            The dataset the profile is computed from
        descriptor (String):
            obs_descriptor used to define the patterns
            defaults to one pattern per row in the dataset
    Returns:
        numpy.ndarray: ``np.linalg.norm`` (default 2-norm) taken along
        axis 1 of the parsed measurements (presumably the channel axis --
        TODO confirm), one entry per pattern
    """
    # The patterns are always used as-is (remove_mean=False); the
    # accompanying descriptor values are not needed for the profile.
    measurements, _ = _parse_input(dataset, descriptor, remove_mean=False)
    return np.linalg.norm(measurements, axis=1)


def calc_rdm_crossnobis(dataset, descriptor, noise=None,
cv_descriptor=None, remove_mean: bool = False):
"""
Expand Down Expand Up @@ -371,7 +446,7 @@ def calc_rdm_crossnobis(dataset, descriptor, noise=None,
measurements[i_fold], measurements[j_fold],
np.linalg.inv(
(variances[i_fold] + variances[j_fold]) / 2)
)
)
rdms.append(rdm)
rdms = np.array(rdms)
rdm = np.einsum('ij->j', rdms) / rdms.shape[0]
Expand Down Expand Up @@ -406,10 +481,10 @@ def calc_rdm_poisson(dataset, descriptor=None, prior_lambda=1,
"""
measurements, desc = _parse_input(dataset, descriptor)
measurements = (measurements + prior_lambda * prior_weight) \
/ (1 + prior_weight)
/ (1 + prior_weight)
kernel = measurements @ np.log(measurements).T
rdm = np.expand_dims(np.diag(kernel), 0) + \
np.expand_dims(np.diag(kernel), 1) - kernel - kernel.T
np.expand_dims(np.diag(kernel), 1) - kernel - kernel.T
rdm = _extract_triu_(rdm) / measurements.shape[1]
return _build_rdms(rdm, dataset, 'poisson', descriptor, desc)

Expand Down Expand Up @@ -455,21 +530,21 @@ def calc_rdm_poisson_cv(dataset, descriptor=None, prior_lambda=1,
measurements_test, _, _ = average_dataset_by(data_test, descriptor)
measurements_train = (measurements_train
+ prior_lambda * prior_weight) \
/ (1 + prior_weight)
/ (1 + prior_weight)
measurements_test = (measurements_test
+ prior_lambda * prior_weight) \
/ (1 + prior_weight)
/ (1 + prior_weight)
kernel = measurements_train @ np.log(measurements_test).T
rdm = np.expand_dims(np.diag(kernel), 0) + \
np.expand_dims(np.diag(kernel), 1) - kernel - kernel.T
np.expand_dims(np.diag(kernel), 1) - kernel - kernel.T
rdm = _extract_triu_(rdm) / measurements_train.shape[1]
return _build_rdms(rdm, dataset, 'poisson_cv', descriptor)


def _calc_rdm_crossnobis_single(meas1, meas2, noise) -> NDArray:
    # Cross-kernel between the two measurement sets, weighted by `noise`
    # (presumably a precision matrix over channels -- TODO confirm).
    cross_kernel = meas1 @ noise @ meas2.T
    same_pattern = np.diag(cross_kernel)
    # d(i, j) = k(i, i) + k(j, j) - k(i, j) - k(j, i)
    distances = same_pattern[np.newaxis, :] + same_pattern[:, np.newaxis] \
        - cross_kernel - cross_kernel.T
    return _extract_triu_(distances) / meas1.shape[1]


Expand All @@ -481,8 +556,8 @@ def _gen_default_cv_descriptor(dataset, descriptor) -> np.ndarray:
desc = dataset.obs_descriptors[descriptor]
values, counts = np.unique(desc, return_counts=True)
assert np.all(counts == counts[0]), (
'cv_descriptor generation failed:\n'
+ 'different number of observations per pattern')
'cv_descriptor generation failed:\n'
+ 'different number of observations per pattern')
n_repeats = counts[0]
cv_descriptor = np.zeros_like(desc)
for i_val in values:
Expand All @@ -491,10 +566,10 @@ def _gen_default_cv_descriptor(dataset, descriptor) -> np.ndarray:


def _parse_input(
dataset: DatasetBase,
descriptor: Optional[str],
remove_mean: bool = False
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
dataset: DatasetBase,
descriptor: Optional[str],
remove_mean: bool = False
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
if descriptor is None:
measurements = dataset.measurements
desc = None
Expand Down
Loading