diff --git a/cca_zoo/probabilistic/_pls.py b/cca_zoo/probabilistic/_pls.py index 6da897ed..a7612858 100644 --- a/cca_zoo/probabilistic/_pls.py +++ b/cca_zoo/probabilistic/_pls.py @@ -79,8 +79,8 @@ def _model(self, views): ) # Add positive-definite constraint for psi1 and psi2 - psi1 = jnp.eye(self.n_features_[0])*1e-3 - psi2 = jnp.eye(self.n_features_[1])*1e-3 + psi1 = jnp.eye(self.n_features_[0]) * 1e-3 + psi2 = jnp.eye(self.n_features_[1]) * 1e-3 mu1 = numpyro.param( "mu_1", diff --git a/cca_zoo/probabilistic/_rcca.py b/cca_zoo/probabilistic/_rcca.py index 51cff806..d7edd289 100644 --- a/cca_zoo/probabilistic/_rcca.py +++ b/cca_zoo/probabilistic/_rcca.py @@ -144,7 +144,7 @@ def _model(self, views): "X1", dist.MultivariateNormal( z @ W1.T + mu1, - covariance_matrix=(1-self.c[0])*psi1, + covariance_matrix=(1 - self.c[0]) * psi1, ), obs=X1, ) @@ -152,7 +152,7 @@ def _model(self, views): "X2", dist.MultivariateNormal( z @ W2.T + mu2, - covariance_matrix=(1-self.c[1])*psi2, + covariance_matrix=(1 - self.c[1]) * psi2, ), obs=X2, ) diff --git a/test/test_probabilistic.py b/test/test_probabilistic.py index 6efdfe93..6dfa8ead 100644 --- a/test/test_probabilistic.py +++ b/test/test_probabilistic.py @@ -70,7 +70,6 @@ def test_cca_vs_probabilisticPLS(setup_data): ), f"Expected correlation with PLS greater than CCA, got {correlation_pls} and {correlation_cca}" - def test_cca_vs_probabilisticRidgeCCA(setup_data): X, Y, joint = setup_data # Initialize models with different regularization parameters