diff --git a/cca_zoo/probabilistic/_pls.py b/cca_zoo/probabilistic/_pls.py
index c2b4546f..6da897ed 100644
--- a/cca_zoo/probabilistic/_pls.py
+++ b/cca_zoo/probabilistic/_pls.py
@@ -79,8 +79,8 @@ def _model(self, views):
         )
 
         # Add positive-definite constraint for psi1 and psi2
-        psi1 = jnp.eye(self.n_features_[0])
-        psi2 = jnp.eye(self.n_features_[1])
+        psi1 = jnp.eye(self.n_features_[0])*1e-3
+        psi2 = jnp.eye(self.n_features_[1])*1e-3
 
         mu1 = numpyro.param(
             "mu_1",
diff --git a/cca_zoo/probabilistic/_rcca.py b/cca_zoo/probabilistic/_rcca.py
index 232a7d56..51cff806 100644
--- a/cca_zoo/probabilistic/_rcca.py
+++ b/cca_zoo/probabilistic/_rcca.py
@@ -144,7 +144,7 @@ def _model(self, views):
             "X1",
             dist.MultivariateNormal(
                 z @ W1.T + mu1,
-                covariance_matrix=psi1 + self.c[0] * jnp.eye(self.n_features_[0]),
+                covariance_matrix=(1-self.c[0])*psi1,
             ),
             obs=X1,
         )
@@ -152,7 +152,7 @@
             "X2",
             dist.MultivariateNormal(
                 z @ W2.T + mu2,
-                covariance_matrix=psi2 + self.c[1] * jnp.eye(self.n_features_[1]),
+                covariance_matrix=(1-self.c[1])*psi2,
             ),
             obs=X2,
         )
diff --git a/cca_zoo/visualisation/scores.py b/cca_zoo/visualisation/scores.py
index b817da64..d7536500 100644
--- a/cca_zoo/visualisation/scores.py
+++ b/cca_zoo/visualisation/scores.py
@@ -172,7 +172,7 @@ def _create_plot(self, x, y, hue, alpha=None, palette=None):
             ax,
         )
 
-    def plot(self):
+    def plot(self, title=None):
         dimensions = self.scores[0].shape[1]
         self.figures_ = []
 
@@ -188,7 +188,7 @@ def plot(self):
             ax.set_xlabel(self.ax_labels[0])
             ax.set_ylabel(self.ax_labels[1])
             # if g is a jointplot, get the underlying figure
-            plt.suptitle(f"Latent Dimension {i+1}")
+            plt.suptitle(f"{title} Latent Dimension {i+1}")
 
             if self.show_corr:
                 if self.test_scores is None:
@@ -226,7 +226,7 @@ def _create_plot(self, x, y, hue=None, palette=None):
 
 
 class SeparateScoreScatterDisplay(ScoreScatterDisplay):
-    def plot(self):
+    def plot(self, title=None):
         dimensions = self.scores[0].shape[1]
         self.train_figures_ = []
         self.test_figures_ = []
@@ -251,7 +251,7 @@ def plot(self):
                     transform=ax.transAxes,
                     verticalalignment="top",
                 )
-            plt.suptitle(f"Test - Latent Dimension {i+1}")
+            plt.suptitle(f"{title} Train - Latent Dimension {i+1}")
             self.train_figures_.append(g)
 
             g, fig, ax = self._create_plot(
@@ -273,7 +273,7 @@ def plot(self):
                     transform=ax.transAxes,
                     verticalalignment="top",
                 )
-            plt.suptitle(f"Test - Latent Dimension {i+1}")
+            plt.suptitle(f"{title} Test - Latent Dimension {i+1}")
             self.test_figures_.append(g)
         plt.tight_layout()
         return self
diff --git a/docs/source/examples/plot_dcca.py b/docs/source/examples/plot_dcca.py
index e78001fd..88566230 100644
--- a/docs/source/examples/plot_dcca.py
+++ b/docs/source/examples/plot_dcca.py
@@ -69,7 +69,6 @@
     dcca, val_loader, labels=val_labels.astype(str)
 )
 score_display.plot()
-score_display.figure_.suptitle("Scatter Deep CCA")
 plt.show()
 
 # UMAP Visualization
@@ -106,7 +105,6 @@
     dcca_eg, val_loader, labels=val_labels.astype(str)
 )
 score_display.plot()
-score_display.figure_.suptitle("Deep CCA EY")
 plt.show()
 
 # %%
@@ -125,7 +123,6 @@
     dcca_noi, val_loader, labels=val_labels.astype(str)
 )
 score_display.plot()
-score_display.figure_.suptitle("Deep CCA by Non-Linear Orthogonal Iterations")
 plt.show()
 
 # %%
@@ -146,7 +143,6 @@
     dcca_sdl, val_loader, labels=val_labels.astype(str)
 )
 score_display.plot()
-score_display.figure_.suptitle("Deep CCA by Stochastic Decorrelation Loss")
 plt.show()
 
 # %%
@@ -168,7 +164,6 @@
     barlowtwins, val_loader, labels=val_labels.astype(str)
 )
 score_display.plot()
-score_display.figure_.suptitle("Deep CCA by Barlow Twins")
 plt.show()
 
 # %%
@@ -190,5 +185,4 @@
     dcca_vicreg, val_loader, labels=val_labels.astype(str)
 )
 score_display.plot()
-score_display.figure_.suptitle("Deep CCA by VICReg")
 plt.show()
diff --git a/test/test_probabilistic.py b/test/test_probabilistic.py
index aa57cd17..6efdfe93 100644
--- a/test/test_probabilistic.py
+++ b/test/test_probabilistic.py
@@ -17,7 +17,7 @@ def setup_data():
         latent_dims=latent_dims,
         random_state=seed,
     )
-    X, Y = data.sample(100)
+    X, Y = data.sample(50)
     X -= X.mean(axis=0)
     Y -= Y.mean(axis=0)
     return X, Y, data.joint_cov
@@ -70,34 +70,34 @@ def test_cca_vs_probabilisticPLS(setup_data):
     ), f"Expected correlation with PLS greater than CCA, got {correlation_pls} and {correlation_cca}"
 
 
-#
-# def test_cca_vs_probabilisticRidgeCCA(setup_data):
-#     X, Y, joint = setup_data
-#     # Initialize models with different regularization parameters
-#     prcca_pls = ProbabilisticRCCA(latent_dimensions=1, random_state=10, c=1.0)
-#     prcca_cca = ProbabilisticRCCA(latent_dimensions=1, random_state=10, c=0)
-#     # Fit and Transform using ProbabilisticRCCA with large and small regularization
-#     prcca_cca.fit([X, Y])
-#     prcca_pls.fit([X, Y])
-#
-#     z_ridge_cca = np.array(prcca_cca.transform([X, None]))
-#     z_ridge_pls = np.array(prcca_pls.transform([X, None]))
-#
-#     # Fit and Transform using classical CCA and PLS
-#     cca = CCA(latent_dimensions=1)
-#     pls = PLS(latent_dimensions=1)
-#
-#     cca.fit([X, Y])
-#     pls.fit([X, Y])
-#
-#     z_cca = np.array(cca.transform([X, Y])[0])
-#     z_pls = np.array(pls.transform([X, Y])[0])
-#
-#     # Assert: Correlations should be high when ProbabilisticRCCA approximates CCA and PLS
-#     corr_matrix_cca = np.abs(np.corrcoef(z_cca.reshape(-1), z_ridge_cca.reshape(-1)))
-#     corr_cca = corr_matrix_cca[0, 1]
-#     assert corr_cca > 0.9, f"Expected correlation greater than 0.9, got {corr_cca}"
-#
-#     corr_matrix_pls = np.abs(np.corrcoef(z_pls.reshape(-1), z_ridge_pls.reshape(-1)))
-#     corr_pls = corr_matrix_pls[0, 1]
-#     assert corr_pls > 0.9, f"Expected correlation greater than 0.9, got {corr_pls}"
+
+def test_cca_vs_probabilisticRidgeCCA(setup_data):
+    X, Y, joint = setup_data
+    # Initialize models with different regularization parameters
+    prcca_pls = ProbabilisticRCCA(latent_dimensions=1, random_state=10, c=0.999)
+    prcca_cca = ProbabilisticRCCA(latent_dimensions=1, random_state=10, c=0)
+    # Fit and Transform using ProbabilisticRCCA with large and small regularization
+    prcca_cca.fit([X, Y])
+    prcca_pls.fit([X, Y])
+
+    z_ridge_cca = np.array(prcca_cca.transform([X, None]))
+    z_ridge_pls = np.array(prcca_pls.transform([X, None]))
+
+    # Fit and Transform using classical CCA and PLS
+    cca = CCA(latent_dimensions=1)
+    pls = PLS(latent_dimensions=1)
+
+    cca.fit([X, Y])
+    pls.fit([X, Y])
+
+    z_cca = np.array(cca.transform([X, Y])[0])
+    z_pls = np.array(pls.transform([X, Y])[0])
+
+    # Assert: Correlations should be high when ProbabilisticRCCA approximates CCA and PLS
+    corr_matrix_cca = np.abs(np.corrcoef(z_cca.reshape(-1), z_ridge_cca.reshape(-1)))
+    corr_cca = corr_matrix_cca[0, 1]
+    assert corr_cca > 0.9, f"Expected correlation greater than 0.9, got {corr_cca}"
+
+    corr_matrix_pls = np.abs(np.corrcoef(z_pls.reshape(-1), z_ridge_pls.reshape(-1)))
+    corr_pls = corr_matrix_pls[0, 1]
+    assert corr_pls > 0.9, f"Expected correlation greater than 0.9, got {corr_pls}"
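Sketch added alongside the patch (not part of it): the _rcca.py hunks replace the per-view observation covariance psi + c * jnp.eye(p) with (1 - c) * psi, and the re-enabled test exercises the two extremes (c=0 tracking classical CCA, c=0.999 tracking PLS). The standalone snippet below only prints the old and new parameterisations side by side for a few values of c; psi, old_observation_cov and new_observation_cov are illustrative stand-ins, not identifiers from the codebase.

import numpy as np

psi = np.diag([0.5, 1.0, 2.0])  # stand-in for a learned view-specific covariance


def old_observation_cov(psi, c):
    # parameterisation removed by the patch: psi + c * I
    return psi + c * np.eye(psi.shape[0])


def new_observation_cov(psi, c):
    # parameterisation introduced by the patch: (1 - c) * psi
    return (1 - c) * psi


for c in (0.0, 0.5, 0.999):
    print(
        f"c={c}: old diag={np.diag(old_observation_cov(psi, c))}, "
        f"new diag={np.diag(new_observation_cov(psi, c))}"
    )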