diff --git a/cca_zoo/probabilistic/__init__.py b/cca_zoo/probabilistic/__init__.py index bd3a29e7..6c6fc2ae 100644 --- a/cca_zoo/probabilistic/__init__.py +++ b/cca_zoo/probabilistic/__init__.py @@ -1,6 +1,7 @@ from ._cca import ProbabilisticCCA from ._plsregression import ProbabilisticPLSRegression from ._rcca import ProbabilisticRCCA + # from ._pls import ProbabilisticPLS __all__ = [ diff --git a/cca_zoo/probabilistic/_cca.py b/cca_zoo/probabilistic/_cca.py index 522becb1..ad717800 100644 --- a/cca_zoo/probabilistic/_cca.py +++ b/cca_zoo/probabilistic/_cca.py @@ -196,7 +196,7 @@ def _guide(self, views): n = X1.shape[0] if X1 is not None else X2.shape[0] with numpyro.plate("n", n): - z=numpyro.sample( + z = numpyro.sample( "representations", dist.MultivariateNormal( jnp.zeros(self.latent_dimensions), jnp.eye(self.latent_dimensions) diff --git a/cca_zoo/probabilistic/_rcca.py b/cca_zoo/probabilistic/_rcca.py index 7b502412..e3656773 100644 --- a/cca_zoo/probabilistic/_rcca.py +++ b/cca_zoo/probabilistic/_rcca.py @@ -144,8 +144,7 @@ def _model(self, views): "X1", dist.MultivariateNormal( z @ W1.T + mu1, - covariance_matrix=psi1 - + self.c[0] * jnp.eye(self.n_features_[0]), + covariance_matrix=psi1 + self.c[0] * jnp.eye(self.n_features_[0]), ), obs=X1, ) @@ -153,8 +152,7 @@ def _model(self, views): "X2", dist.MultivariateNormal( z @ W2.T + mu2, - covariance_matrix=psi2 - + self.c[1] * jnp.eye(self.n_features_[1]), + covariance_matrix=psi2 + self.c[1] * jnp.eye(self.n_features_[1]), ), obs=X2, ) diff --git a/test/test_probabilistic.py b/test/test_probabilistic.py index af2c218f..ee4cc0cf 100644 --- a/test/test_probabilistic.py +++ b/test/test_probabilistic.py @@ -39,4 +39,4 @@ def test_cca_vs_probabilisticCCA(setup_data): assert ( correlation > 0.9 - ), f"Expected correlation greater than 0.95, got {correlation}" \ No newline at end of file + ), f"Expected correlation greater than 0.95, got {correlation}"