diff --git a/cca_zoo/probabilistic/__init__.py b/cca_zoo/probabilistic/__init__.py index 57b90237..16ea1e83 100644 --- a/cca_zoo/probabilistic/__init__.py +++ b/cca_zoo/probabilistic/__init__.py @@ -2,6 +2,4 @@ from ._pls import ProbabilisticPLS from ._rcca import ProbabilisticRCCA -__all__ = ["ProbabilisticCCA", - "ProbabilisticPLS", - "ProbabilisticRCCA"] +__all__ = ["ProbabilisticCCA", "ProbabilisticPLS", "ProbabilisticRCCA"] diff --git a/cca_zoo/probabilistic/_cca.py b/cca_zoo/probabilistic/_cca.py index c79ecffc..4109b21b 100644 --- a/cca_zoo/probabilistic/_cca.py +++ b/cca_zoo/probabilistic/_cca.py @@ -126,7 +126,9 @@ def fit_mcmc(self, views: Iterable[np.ndarray], y=None): """ views = self._validate_data(views) nuts_kernel = NUTS(self._model) - mcmc = MCMC(nuts_kernel, num_warmup=self.num_warmup, num_samples=self.num_samples) + mcmc = MCMC( + nuts_kernel, num_warmup=self.num_warmup, num_samples=self.num_samples + ) mcmc.run(self.rng_key, views) self.params = mcmc.get_samples() return self diff --git a/cca_zoo/probabilistic/_pls.py b/cca_zoo/probabilistic/_pls.py index 2cd36f1f..c2b4546f 100644 --- a/cca_zoo/probabilistic/_pls.py +++ b/cca_zoo/probabilistic/_pls.py @@ -6,6 +6,7 @@ from cca_zoo.probabilistic._cca import ProbabilisticCCA import numpy as np + class ProbabilisticPLS(ProbabilisticCCA): """ Probabilistic Ridge Canonical Correlation Analysis (Probabilistic Ridge CCA). @@ -134,4 +135,4 @@ def joint(self): # Construct the matrix using the blocks matrix = np.block([[top_left, top_right], [bottom_left, bottom_right]]) - return matrix \ No newline at end of file + return matrix diff --git a/cca_zoo/probabilistic/_rcca.py b/cca_zoo/probabilistic/_rcca.py index f161bbf9..cfb9cd5e 100644 --- a/cca_zoo/probabilistic/_rcca.py +++ b/cca_zoo/probabilistic/_rcca.py @@ -6,6 +6,7 @@ from cca_zoo.probabilistic._cca import ProbabilisticCCA import numpy as np + class ProbabilisticRCCA(ProbabilisticCCA): """ Probabilistic Ridge Canonical Correlation Analysis (Probabilistic Ridge CCA). @@ -100,8 +101,12 @@ def _model(self, views): ) # Add positive-definite constraint for psi1 and psi2 - psi1 = numpyro.param("psi_1", jnp.eye(self.n_features_[0])) + self.c * jnp.eye(self.n_features_[0]) - psi2 = numpyro.param("psi_2", jnp.eye(self.n_features_[1])) + self.c * jnp.eye(self.n_features_[1]) + psi1 = numpyro.param("psi_1", jnp.eye(self.n_features_[0])) + self.c * jnp.eye( + self.n_features_[0] + ) + psi2 = numpyro.param("psi_2", jnp.eye(self.n_features_[1])) + self.c * jnp.eye( + self.n_features_[1] + ) mu1 = numpyro.param( "mu_1", @@ -148,12 +153,16 @@ def _model(self, views): def joint(self): # Calculate the individual matrix blocks - top_left = self.params["W_1"] @ self.params["W_1"].T + self.c*jnp.eye(self.n_features_[0]) - bottom_right = self.params["W_2"] @ self.params["W_2"].T + self.c*jnp.eye(self.n_features_[1]) + top_left = self.params["W_1"] @ self.params["W_1"].T + self.c * jnp.eye( + self.n_features_[0] + ) + bottom_right = self.params["W_2"] @ self.params["W_2"].T + self.c * jnp.eye( + self.n_features_[1] + ) top_right = self.params["W_1"] @ self.params["W_2"].T bottom_left = self.params["W_2"] @ self.params["W_1"].T # Construct the matrix using the blocks matrix = np.block([[top_left, top_right], [bottom_left, bottom_right]]) - return matrix \ No newline at end of file + return matrix diff --git a/test/test_probabilistic.py b/test/test_probabilistic.py index 2db7bc64..a73256a8 100644 --- a/test/test_probabilistic.py +++ b/test/test_probabilistic.py @@ -22,6 +22,7 @@ def setup_data(): Y -= Y.mean(axis=0) return X, Y + def test_cca_vs_probabilisticCCA(setup_data): X, Y = setup_data # Models and fit @@ -38,9 +39,10 @@ def test_cca_vs_probabilisticCCA(setup_data): correlation = correlation_matrix[0, 1] assert ( - correlation > 0.95 + correlation > 0.95 ), f"Expected correlation greater than 0.95, got {correlation}" + def test_cca_vs_probabilisticPLS(setup_data): X, Y = setup_data # Models and fit @@ -64,12 +66,13 @@ def test_cca_vs_probabilisticPLS(setup_data): correlation_cca = correlation_matrix[0, 1] assert ( - correlation_pls > correlation_cca + correlation_pls > correlation_cca ), f"Expected correlation with PLS greater than CCA, got {correlation_pls} and {correlation_cca}" assert ( - correlation_pls > 0.95 + correlation_pls > 0.95 ), f"Expected correlation greater than 0.85, got {correlation_pls}" + def test_cca_vs_probabilisticRidgeCCA(setup_data): X, Y = setup_data # Initialize models with different regularization parameters