Skip to content

Commit

Permalink
Validation set for Gradient Based Models woohoo
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Oct 25, 2023
1 parent f10a473 commit 6d9e8a9
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 10 deletions.
2 changes: 2 additions & 0 deletions cca_zoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
visualisation,
preprocessing,
sequential,
nonparametric,
)

__all__ = [
Expand All @@ -16,6 +17,7 @@
"visualisation",
"preprocessing",
"sequential",
"nonparametric",
]
try:
from . import probabilistic
Expand Down
28 changes: 20 additions & 8 deletions cca_zoo/linear/_tcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,27 @@ def correlations(self, views: Iterable[np.ndarray], **kwargs):
corrs = multiplied_views / norms
return corrs

def score(self, views: Iterable[np.ndarray], **kwargs):
"""
Returns the higher order correlations in each dimension
def average_pairwise_correlations(
self, views: Iterable[np.ndarray], **kwargs
) -> np.ndarray:
transformed_views = self.transform(views, **kwargs)
transformed_views = [
transformed_view - transformed_view.mean(axis=0)
for transformed_view in transformed_views
]
multiplied_views = np.stack(transformed_views, axis=0).prod(axis=0).sum(axis=0)
norms = np.stack(
[
np.linalg.norm(transformed_view, axis=0)
for transformed_view in transformed_views
],
axis=0,
).prod(axis=0)
corrs = multiplied_views / norms
return corrs

:param views: list/tuple of numpy arrays or array likes with the same number of rows (samples)
:param kwargs: any additional keyword arguments required by the given model
"""
dim_corrs = self.correlations(views, **kwargs)
return dim_corrs
def score(self, views: Iterable[np.ndarray], **kwargs):
return self.average_pairwise_correlations(views, **kwargs).mean()

def _setup_tensor(self, views: Iterable[np.ndarray], **kwargs):
covs = [
Expand Down
2 changes: 1 addition & 1 deletion cca_zoo/nonparametric/_ncca.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sklearn.neighbors import NearestNeighbors

from cca_zoo._base import BaseModel
from cca_zoo.utils.check_values import _process_parameter
from cca_zoo._utils.check_values import _process_parameter


class NCCA(BaseModel):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_unregularized.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def test_unregularized_methods(data):
CCA(latent_dimensions=latent_dims),
KCCA(latent_dimensions=latent_dims),
PCACCA(latent_dimensions=latent_dims),
TCCA(latent_dimensions=latent_dims),
KTCCA(latent_dimensions=latent_dims),
]

scores = [
Expand All @@ -57,7 +59,6 @@ def test_unregularized_multi(data):
MCCA(latent_dimensions=latent_dims, pca=False),
MCCA(latent_dimensions=latent_dims, pca=True),
KGCCA(latent_dimensions=latent_dims),
TCCA(latent_dimensions=latent_dims),
]

scores = [method.fit((X, Y, Z)).score((X, Y, Z)) for method in methods]
Expand Down

0 comments on commit 6d9e8a9

Please sign in to comment.