Skip to content

Commit

Permalink
Format code with black
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 authored and github-actions[bot] committed Oct 3, 2023
1 parent cc35455 commit d7cf546
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 13 deletions.
4 changes: 1 addition & 3 deletions cca_zoo/probabilistic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,4 @@
from ._pls import ProbabilisticPLS
from ._rcca import ProbabilisticRCCA

__all__ = ["ProbabilisticCCA",
"ProbabilisticPLS",
"ProbabilisticRCCA"]
__all__ = ["ProbabilisticCCA", "ProbabilisticPLS", "ProbabilisticRCCA"]
4 changes: 3 additions & 1 deletion cca_zoo/probabilistic/_cca.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def fit_mcmc(self, views: Iterable[np.ndarray], y=None):
"""
views = self._validate_data(views)
nuts_kernel = NUTS(self._model)
mcmc = MCMC(nuts_kernel, num_warmup=self.num_warmup, num_samples=self.num_samples)
mcmc = MCMC(
nuts_kernel, num_warmup=self.num_warmup, num_samples=self.num_samples
)
mcmc.run(self.rng_key, views)
self.params = mcmc.get_samples()
return self
Expand Down
3 changes: 2 additions & 1 deletion cca_zoo/probabilistic/_pls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from cca_zoo.probabilistic._cca import ProbabilisticCCA
import numpy as np


class ProbabilisticPLS(ProbabilisticCCA):
"""
Probabilistic Ridge Canonical Correlation Analysis (Probabilistic Ridge CCA).
Expand Down Expand Up @@ -134,4 +135,4 @@ def joint(self):
# Construct the matrix using the blocks
matrix = np.block([[top_left, top_right], [bottom_left, bottom_right]])

return matrix
return matrix
19 changes: 14 additions & 5 deletions cca_zoo/probabilistic/_rcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from cca_zoo.probabilistic._cca import ProbabilisticCCA
import numpy as np


class ProbabilisticRCCA(ProbabilisticCCA):
"""
Probabilistic Ridge Canonical Correlation Analysis (Probabilistic Ridge CCA).
Expand Down Expand Up @@ -100,8 +101,12 @@ def _model(self, views):
)

# Add positive-definite constraint for psi1 and psi2
psi1 = numpyro.param("psi_1", jnp.eye(self.n_features_[0])) + self.c * jnp.eye(self.n_features_[0])
psi2 = numpyro.param("psi_2", jnp.eye(self.n_features_[1])) + self.c * jnp.eye(self.n_features_[1])
psi1 = numpyro.param("psi_1", jnp.eye(self.n_features_[0])) + self.c * jnp.eye(
self.n_features_[0]
)
psi2 = numpyro.param("psi_2", jnp.eye(self.n_features_[1])) + self.c * jnp.eye(
self.n_features_[1]
)

mu1 = numpyro.param(
"mu_1",
Expand Down Expand Up @@ -148,12 +153,16 @@ def _model(self, views):

def joint(self):
# Calculate the individual matrix blocks
top_left = self.params["W_1"] @ self.params["W_1"].T + self.c*jnp.eye(self.n_features_[0])
bottom_right = self.params["W_2"] @ self.params["W_2"].T + self.c*jnp.eye(self.n_features_[1])
top_left = self.params["W_1"] @ self.params["W_1"].T + self.c * jnp.eye(
self.n_features_[0]
)
bottom_right = self.params["W_2"] @ self.params["W_2"].T + self.c * jnp.eye(
self.n_features_[1]
)
top_right = self.params["W_1"] @ self.params["W_2"].T
bottom_left = self.params["W_2"] @ self.params["W_1"].T

# Construct the matrix using the blocks
matrix = np.block([[top_left, top_right], [bottom_left, bottom_right]])

return matrix
return matrix
9 changes: 6 additions & 3 deletions test/test_probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def setup_data():
Y -= Y.mean(axis=0)
return X, Y


def test_cca_vs_probabilisticCCA(setup_data):
X, Y = setup_data
# Models and fit
Expand All @@ -38,9 +39,10 @@ def test_cca_vs_probabilisticCCA(setup_data):
correlation = correlation_matrix[0, 1]

assert (
correlation > 0.95
correlation > 0.95
), f"Expected correlation greater than 0.95, got {correlation}"


def test_cca_vs_probabilisticPLS(setup_data):
X, Y = setup_data
# Models and fit
Expand All @@ -64,12 +66,13 @@ def test_cca_vs_probabilisticPLS(setup_data):
correlation_cca = correlation_matrix[0, 1]

assert (
correlation_pls > correlation_cca
correlation_pls > correlation_cca
), f"Expected correlation with PLS greater than CCA, got {correlation_pls} and {correlation_cca}"
assert (
correlation_pls > 0.95
correlation_pls > 0.95
), f"Expected correlation greater than 0.85, got {correlation_pls}"


def test_cca_vs_probabilisticRidgeCCA(setup_data):
X, Y = setup_data
# Initialize models with different regularization parameters
Expand Down

0 comments on commit d7cf546

Please sign in to comment.