Skip to content

Commit

Permalink
Validation set for Gradient Based Models woohoo
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Oct 19, 2023
1 parent 2cf540a commit c161fcf
Show file tree
Hide file tree
Showing 21 changed files with 188 additions and 193 deletions.
72 changes: 36 additions & 36 deletions cca_zoo/deep/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,26 +266,26 @@ def loss(self, representations, independent_representations=None):
}


class CCA_SVDLoss(CCA_EYLoss):
def loss(self, representations, independent_representations=None):
C = torch.cov(torch.hstack(representations).T)
latent_dims = representations[0].shape[1]

Cxy = C[:latent_dims, latent_dims:]
Cxx = C[:latent_dims, :latent_dims]

if independent_representations is None:
Cyy = C[latent_dims:, latent_dims:]
else:
Cyy = torch.cov(independent_representations[1].T)

rewards = torch.trace(2 * Cxy)
penalties = torch.trace(Cxx @ Cyy)
return {
"objective": -rewards + penalties, # return the negative objective value
"rewards": rewards, # return the total rewards
"penalties": penalties, # return the penalties matrix
}
# class CCA_SVDLoss(CCA_EYLoss):
# def loss(self, representations, independent_representations=None):
# C = torch.cov(torch.hstack(representations).T)
# latent_dims = representations[0].shape[1]
#
# Cxy = C[:latent_dims, latent_dims:]
# Cxx = C[:latent_dims, :latent_dims]
#
# if independent_representations is None:
# Cyy = C[latent_dims:, latent_dims:]
# else:
# Cyy = torch.cov(independent_representations[1].T)
#
# rewards = torch.trace(2 * Cxy)
# penalties = torch.trace(Cxx @ Cyy)
# return {
# "objective": -rewards + penalties, # return the negative objective value
# "rewards": rewards, # return the total rewards
# "penalties": penalties, # return the penalties matrix
# }


class PLS_EYLoss(CCA_EYLoss):
Expand Down Expand Up @@ -317,19 +317,19 @@ def get_AB(self, representations, weights=None):
return A / len(representations), B / len(representations)


class PLS_SVDLoss(PLS_EYLoss):
def loss(self, representations, weights=None):
C = cross_cov(representations[0], representations[1], rowvar=False)

Cxy = C
Cxx = weights[0].T @ weights[0] / representations[0].shape[0]
Cyy = weights[1].T @ weights[1] / representations[1].shape[0]

rewards = torch.trace(2 * Cxy)
penalties = torch.trace(Cxx @ Cyy)

return {
"objective": -rewards + penalties,
"rewards": rewards,
"penalties": penalties,
}
# class PLS_SVDLoss(PLS_EYLoss):
# def loss(self, representations, weights=None):
# C = cross_cov(representations[0], representations[1], rowvar=False)
#
# Cxy = C
# Cxx = weights[0].T @ weights[0] / representations[0].shape[0]
# Cyy = weights[1].T @ weights[1] / representations[1].shape[0]
#
# rewards = torch.trace(2 * Cxy)
# penalties = torch.trace(Cxx @ Cyy)
#
# return {
# "objective": -rewards + penalties,
# "rewards": rewards,
# "penalties": penalties,
# }
4 changes: 2 additions & 2 deletions cca_zoo/linear/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._gcca import GCCA
from ._gradient import CCA_EY, CCA_GHA, CCA_SVD, PLS_EY, PLSStochasticPower
from ._gradient import CCA_EY, CCA_GHA, PLS_EY, PLSStochasticPower
from ._grcca import GRCCA
from ._iterative import (
PLS_ALS,
Expand Down Expand Up @@ -39,6 +39,6 @@
"CCA_EY",
"PLS_EY",
"CCA_GHA",
"CCA_SVD",
# "CCA_SVD",
"PLSStochasticPower",
]
4 changes: 2 additions & 2 deletions cca_zoo/linear/_gradient/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from ._gha import CCA_GHA

from ._stochasticpls import PLSStochasticPower
from ._svd import CCA_SVD
# from ._svd import CCA_SVD

__all__ = [
"CCA_EY",
"PLS_EY",
"CCA_GHA",
"CCA_SVD",
# "CCA_SVD",
# "PLS_SVD",
"PLSStochasticPower",
]
2 changes: 1 addition & 1 deletion cca_zoo/linear/_gradient/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
batch_size=None,
dataloader_kwargs=None,
epochs=1,
learning_rate=1,
learning_rate=5e-2,
initialization: Union[str, callable] = "random",
trainer_kwargs=None,
optimizer_kwargs=None,
Expand Down
1 change: 1 addition & 0 deletions cca_zoo/probabilistic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ._cca import ProbabilisticCCA
from ._plsregression import ProbabilisticPLSRegression
from ._rcca import ProbabilisticRCCA

# from ._pls import ProbabilisticPLS

__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion cca_zoo/probabilistic/_cca.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def _guide(self, views):
n = X1.shape[0] if X1 is not None else X2.shape[0]

with numpyro.plate("n", n):
z=numpyro.sample(
z = numpyro.sample(
"representations",
dist.MultivariateNormal(
jnp.zeros(self.latent_dimensions), jnp.eye(self.latent_dimensions)
Expand Down
6 changes: 2 additions & 4 deletions cca_zoo/probabilistic/_rcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,17 +144,15 @@ def _model(self, views):
"X1",
dist.MultivariateNormal(
z @ W1.T + mu1,
covariance_matrix=psi1
+ self.c[0] * jnp.eye(self.n_features_[0]),
covariance_matrix=psi1 + self.c[0] * jnp.eye(self.n_features_[0]),
),
obs=X1,
)
numpyro.sample(
"X2",
dist.MultivariateNormal(
z @ W2.T + mu2,
covariance_matrix=psi2
+ self.c[1] * jnp.eye(self.n_features_[1]),
covariance_matrix=psi2 + self.c[1] * jnp.eye(self.n_features_[1]),
),
obs=X2,
)
Expand Down
2 changes: 1 addition & 1 deletion cca_zoo/visualisation/umap_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def _validate_plot_params(self):
check_umap_support("UMAPScoreDisplay")
check_seaborn_support("TSNEScoreDisplay")

def plot(self):
def plot(self, **kwargs):
self._validate_plot_params()
import umap
import matplotlib.pyplot as plt
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ pytest-cov = "*"
seaborn = "*"
opentsne = "*"
umap-learn = "*"
rdata = "*"


[build-system]
Expand Down
15 changes: 9 additions & 6 deletions test/test_cross_corr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,25 @@

from cca_zoo.utils.cross_correlation import cross_corrcoef, cross_cov

N = 50
features = [4, 6]


def test_crosscorrcoef():
X = np.random.rand(100, 5)
Y = np.random.rand(100, 5) / 10
X = np.random.rand(N, features[0])
Y = np.random.rand(N, features[1]) / 10

m = np.corrcoef(X, Y, rowvar=False)[:5, 5:]
m = np.corrcoef(X, Y, rowvar=False)[:4, 4:]
n = cross_corrcoef(X, Y, rowvar=False)

assert np.allclose(m, n)


def test_crosscov(bias=False):
X = np.random.rand(100, 5)
Y = np.random.rand(100, 5) / 10
X = np.random.rand(N, features[0])
Y = np.random.rand(N, features[1]) / 10

m = np.cov(X, Y, rowvar=False)[:5, 5:]
m = np.cov(X, Y, rowvar=False)[:4, 4:]
n = cross_cov(X, Y, rowvar=False)

assert np.allclose(m, n)
8 changes: 4 additions & 4 deletions test/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def test_cca_on_simulated_data_maintains_expected_correlation(
# Additional test to verify the shape of generated data
def test_simulated_data_shapes():
data = JointData(
view_features=[10, 12], latent_dims=4, correlation=[0.8, 0.7, 0.6, 0.5]
view_features=[4, 6], latent_dims=4, correlation=[0.8, 0.7, 0.6, 0.5]
)
x_train, y_train = data.sample(500)
assert x_train.shape == (500, 10)
assert y_train.shape == (500, 12)
x_train, y_train = data.sample(5)
assert x_train.shape == (5, 4)
assert y_train.shape == (5, 6)
Loading

0 comments on commit c161fcf

Please sign in to comment.