Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Oct 6, 2023
2 parents 6b049ec + dd4c91b commit a969fce
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 5 deletions.
4 changes: 2 additions & 2 deletions cca_zoo/probabilistic/_pls.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def _model(self, views):
)

# Add positive-definite constraint for psi1 and psi2
psi1 = jnp.eye(self.n_features_[0])*1e-3
psi2 = jnp.eye(self.n_features_[1])*1e-3
psi1 = jnp.eye(self.n_features_[0]) * 1e-3
psi2 = jnp.eye(self.n_features_[1]) * 1e-3

mu1 = numpyro.param(
"mu_1",
Expand Down
4 changes: 2 additions & 2 deletions cca_zoo/probabilistic/_rcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,15 +144,15 @@ def _model(self, views):
"X1",
dist.MultivariateNormal(
z @ W1.T + mu1,
covariance_matrix=(1-self.c[0])*psi1,
covariance_matrix=(1 - self.c[0]) * psi1,
),
obs=X1,
)
numpyro.sample(
"X2",
dist.MultivariateNormal(
z @ W2.T + mu2,
covariance_matrix=(1-self.c[1])*psi2,
covariance_matrix=(1 - self.c[1]) * psi2,
),
obs=X2,
)
Expand Down
1 change: 0 additions & 1 deletion test/test_probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def test_cca_vs_probabilisticPLS(setup_data):
), f"Expected correlation with PLS greater than CCA, got {correlation_pls} and {correlation_cca}"



def test_cca_vs_probabilisticRidgeCCA(setup_data):
X, Y, joint = setup_data
# Initialize models with different regularization parameters
Expand Down

0 comments on commit a969fce

Please sign in to comment.