Validation set for Gradient Based Models woohoo
jameschapman19 committed Oct 6, 2023
1 parent 0dc556d commit ca65398
Showing 5 changed files with 41 additions and 47 deletions.
4 changes: 2 additions & 2 deletions cca_zoo/probabilistic/_pls.py
@@ -79,8 +79,8 @@ def _model(self, views):
         )
 
         # Add positive-definite constraint for psi1 and psi2
-        psi1 = jnp.eye(self.n_features_[0])
-        psi2 = jnp.eye(self.n_features_[1])
+        psi1 = jnp.eye(self.n_features_[0])*1e-3
+        psi2 = jnp.eye(self.n_features_[1])*1e-3
 
         mu1 = numpyro.param(
             "mu_1",
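For context, psi1 and psi2 act as per-view observation covariances in the probabilistic PLS model; fixing them at 1e-3 times the identity makes the observations nearly noiseless around the latent projection. The sketch below is illustrative only, assuming the same numpyro.sample pattern used in _rcca.py further down; it is not the library's _model implementation and all names are placeholders.

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# Illustrative sketch only (not the library's _model): how a fixed, small
# diagonal covariance enters a PLS-style observation likelihood. Intended to
# be used inside a numpyro model/inference context; X1, z, W1, mu1 are
# hypothetical placeholders.
def toy_observation_model(X1, z, W1, mu1):
    psi1 = jnp.eye(X1.shape[1]) * 1e-3  # fixed, near-noiseless observation covariance
    numpyro.sample(
        "X1",
        dist.MultivariateNormal(z @ W1.T + mu1, covariance_matrix=psi1),
        obs=X1,
    )
```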
4 changes: 2 additions & 2 deletions cca_zoo/probabilistic/_rcca.py
@@ -144,15 +144,15 @@ def _model(self, views):
             "X1",
             dist.MultivariateNormal(
                 z @ W1.T + mu1,
-                covariance_matrix=psi1 + self.c[0] * jnp.eye(self.n_features_[0]),
+                covariance_matrix=(1-self.c[0])*psi1,
             ),
             obs=X1,
         )
         numpyro.sample(
             "X2",
             dist.MultivariateNormal(
                 z @ W2.T + mu2,
-                covariance_matrix=psi2 + self.c[1] * jnp.eye(self.n_features_[1]),
+                covariance_matrix=(1-self.c[1])*psi2,
             ),
             obs=X2,
         )
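The reparameterisation above replaces the additive ridge term psi + c * I with a shrinkage (1 - c) * psi. A rough numerical sketch of the two limits, using a hypothetical learned covariance rather than library code:

```python
import numpy as np

# Rough illustration (not library code) of the reparameterised observation
# covariance. `psi` stands in for the learned per-view covariance.
def effective_covariance(psi: np.ndarray, c: float) -> np.ndarray:
    # c = 0  -> psi is used as-is, the CCA-flavoured model;
    # c -> 1 -> the covariance shrinks towards zero, approaching the
    #           near-noiseless PLS model above (psi = 1e-3 * I).
    return (1 - c) * psi

psi = np.diag([0.5, 1.0, 2.0])            # hypothetical learned covariance
print(effective_covariance(psi, 0.0))     # unchanged
print(effective_covariance(psi, 0.999))   # almost zero
```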
10 changes: 5 additions & 5 deletions cca_zoo/visualisation/scores.py
@@ -172,7 +172,7 @@ def _create_plot(self, x, y, hue, alpha=None, palette=None):
             ax,
         )
 
-    def plot(self):
+    def plot(self, title=None):
         dimensions = self.scores[0].shape[1]
         self.figures_ = []
 
@@ -188,7 +188,7 @@ def plot(self):
             ax.set_xlabel(self.ax_labels[0])
             ax.set_ylabel(self.ax_labels[1])
             # if g is a jointplot, get the underlying figure
-            plt.suptitle(f"Latent Dimension {i+1}")
+            plt.suptitle(f"{title} Latent Dimension {i+1}")
 
             if self.show_corr:
                 if self.test_scores is None:
@@ -226,7 +226,7 @@ def _create_plot(self, x, y, hue=None, palette=None):
 
 
 class SeparateScoreScatterDisplay(ScoreScatterDisplay):
-    def plot(self):
+    def plot(self, title=None):
         dimensions = self.scores[0].shape[1]
         self.train_figures_ = []
         self.test_figures_ = []
@@ -251,7 +251,7 @@ def plot(self):
                     transform=ax.transAxes,
                     verticalalignment="top",
                 )
-            plt.suptitle(f"Test - Latent Dimension {i+1}")
+            plt.suptitle(f"{title} Train - Latent Dimension {i+1}")
             self.train_figures_.append(g)
 
             g, fig, ax = self._create_plot(
@@ -273,7 +273,7 @@ def plot(self):
                    transform=ax.transAxes,
                    verticalalignment="top",
                )
-            plt.suptitle(f"Test - Latent Dimension {i+1}")
+            plt.suptitle(f"{title} Test - Latent Dimension {i+1}")
            self.test_figures_.append(g)
        plt.tight_layout()
        return self
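With the new optional title argument, callers can label the score figures at plot time rather than reaching into figure_.suptitle afterwards, which is exactly what the docs example change below removes. A hedged usage sketch, assuming score_display is an already-constructed ScoreScatterDisplay instance (the constructor call is not shown in this diff, so it is left out here):

```python
# Hypothetical usage of the new keyword argument; `score_display` is assumed
# to be a ScoreScatterDisplay built elsewhere (e.g. as in plot_dcca.py below).
score_display.plot(title="Deep CCA")
# Each per-dimension figure is titled "Deep CCA Latent Dimension 1", "... 2", etc.
```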
6 changes: 0 additions & 6 deletions docs/source/examples/plot_dcca.py
@@ -69,7 +69,6 @@
     dcca, val_loader, labels=val_labels.astype(str)
 )
 score_display.plot()
-score_display.figure_.suptitle("Scatter Deep CCA")
 plt.show()
 
 # UMAP Visualization
@@ -106,7 +105,6 @@
     dcca_eg, val_loader, labels=val_labels.astype(str)
 )
 score_display.plot()
-score_display.figure_.suptitle("Deep CCA EY")
 plt.show()
 
 # %%
@@ -125,7 +123,6 @@
     dcca_noi, val_loader, labels=val_labels.astype(str)
 )
 score_display.plot()
-score_display.figure_.suptitle("Deep CCA by Non-Linear Orthogonal Iterations")
 plt.show()
 
 # %%
@@ -146,7 +143,6 @@
     dcca_sdl, val_loader, labels=val_labels.astype(str)
 )
 score_display.plot()
-score_display.figure_.suptitle("Deep CCA by Stochastic Decorrelation Loss")
 plt.show()
 
 # %%
@@ -168,7 +164,6 @@
     barlowtwins, val_loader, labels=val_labels.astype(str)
 )
 score_display.plot()
-score_display.figure_.suptitle("Deep CCA by Barlow Twins")
 plt.show()
 
 # %%
@@ -190,5 +185,4 @@
     dcca_vicreg, val_loader, labels=val_labels.astype(str)
 )
 score_display.plot()
-score_display.figure_.suptitle("Deep CCA by VICReg")
 plt.show()
64 changes: 32 additions & 32 deletions test/test_probabilistic.py
@@ -17,7 +17,7 @@ def setup_data():
         latent_dims=latent_dims,
         random_state=seed,
     )
-    X, Y = data.sample(100)
+    X, Y = data.sample(50)
     X -= X.mean(axis=0)
     Y -= Y.mean(axis=0)
     return X, Y, data.joint_cov
@@ -70,34 +70,34 @@ def test_cca_vs_probabilisticPLS(setup_data):
     ), f"Expected correlation with PLS greater than CCA, got {correlation_pls} and {correlation_cca}"
 
 
-#
-# def test_cca_vs_probabilisticRidgeCCA(setup_data):
-#     X, Y, joint = setup_data
-#     # Initialize models with different regularization parameters
-#     prcca_pls = ProbabilisticRCCA(latent_dimensions=1, random_state=10, c=1.0)
-#     prcca_cca = ProbabilisticRCCA(latent_dimensions=1, random_state=10, c=0)
-#     # Fit and Transform using ProbabilisticRCCA with large and small regularization
-#     prcca_cca.fit([X, Y])
-#     prcca_pls.fit([X, Y])
-#
-#     z_ridge_cca = np.array(prcca_cca.transform([X, None]))
-#     z_ridge_pls = np.array(prcca_pls.transform([X, None]))
-#
-#     # Fit and Transform using classical CCA and PLS
-#     cca = CCA(latent_dimensions=1)
-#     pls = PLS(latent_dimensions=1)
-#
-#     cca.fit([X, Y])
-#     pls.fit([X, Y])
-#
-#     z_cca = np.array(cca.transform([X, Y])[0])
-#     z_pls = np.array(pls.transform([X, Y])[0])
-#
-#     # Assert: Correlations should be high when ProbabilisticRCCA approximates CCA and PLS
-#     corr_matrix_cca = np.abs(np.corrcoef(z_cca.reshape(-1), z_ridge_cca.reshape(-1)))
-#     corr_cca = corr_matrix_cca[0, 1]
-#     assert corr_cca > 0.9, f"Expected correlation greater than 0.9, got {corr_cca}"
-#
-#     corr_matrix_pls = np.abs(np.corrcoef(z_pls.reshape(-1), z_ridge_pls.reshape(-1)))
-#     corr_pls = corr_matrix_pls[0, 1]
-#     assert corr_pls > 0.9, f"Expected correlation greater than 0.9, got {corr_pls}"
+def test_cca_vs_probabilisticRidgeCCA(setup_data):
+    X, Y, joint = setup_data
+    # Initialize models with different regularization parameters
+    prcca_pls = ProbabilisticRCCA(latent_dimensions=1, random_state=10, c=0.999)
+    prcca_cca = ProbabilisticRCCA(latent_dimensions=1, random_state=10, c=0)
+    # Fit and Transform using ProbabilisticRCCA with large and small regularization
+    prcca_cca.fit([X, Y])
+    prcca_pls.fit([X, Y])
+
+    z_ridge_cca = np.array(prcca_cca.transform([X, None]))
+    z_ridge_pls = np.array(prcca_pls.transform([X, None]))
+
+    # Fit and Transform using classical CCA and PLS
+    cca = CCA(latent_dimensions=1)
+    pls = PLS(latent_dimensions=1)
+
+    cca.fit([X, Y])
+    pls.fit([X, Y])
+
+    z_cca = np.array(cca.transform([X, Y])[0])
+    z_pls = np.array(pls.transform([X, Y])[0])
+
+    # Assert: Correlations should be high when ProbabilisticRCCA approximates CCA and PLS
+    corr_matrix_cca = np.abs(np.corrcoef(z_cca.reshape(-1), z_ridge_cca.reshape(-1)))
+    corr_cca = corr_matrix_cca[0, 1]
+    assert corr_cca > 0.9, f"Expected correlation greater than 0.9, got {corr_cca}"
+
+    corr_matrix_pls = np.abs(np.corrcoef(z_pls.reshape(-1), z_ridge_pls.reshape(-1)))
+    corr_pls = corr_matrix_pls[0, 1]
+    assert corr_pls > 0.9, f"Expected correlation greater than 0.9, got {corr_pls}"

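A likely reason the reinstated ridge test uses c=0.999 rather than the previous c=1.0 (an inference from the new (1 - c) * psi parameterisation, not something the commit states): at c = 1.0 the observation covariance collapses to the zero matrix, which is not a valid MultivariateNormal covariance. A quick check:

```python
import numpy as np

# Hypothetical learned covariance; the point is only the c = 1.0 edge case.
psi = np.eye(3)
for c in (0.999, 1.0):
    cov = (1 - c) * psi
    # Positive-definite for c = 0.999 (tiny but valid), degenerate for c = 1.0.
    print(c, np.all(np.linalg.eigvalsh(cov) > 0))
```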