Skip to content

Commit

Permalink
Pulling the pytorch stuff out of simple linear models as didn't add a…
Browse files Browse the repository at this point in the history
…nything and made things complex
  • Loading branch information
jameschapman19 committed Aug 7, 2023
1 parent 31bb1ed commit 6e5d85a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 61 deletions.
9 changes: 6 additions & 3 deletions cca_zoo/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,12 @@ def pairwise_correlations(self, views: Iterable[np.ndarray], **kwargs):
]
)
)
all_corrs = np.array(all_corrs).reshape(
(self.n_views_, self.n_views_, self.latent_dimensions)
)
try:
all_corrs = np.array(all_corrs).reshape(
(self.n_views_, self.n_views_, self.latent_dimensions)
)
except:
print()
return all_corrs

def score(self, views: Iterable[np.ndarray], y=None, **kwargs):
Expand Down
92 changes: 34 additions & 58 deletions docs/source/examples/plot_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,69 +6,21 @@
permutation testing, learning curves, and cross-validation.
"""

# %%
# Import libraries
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import ShuffleSplit, KFold

from cca_zoo.data.simulated import LinearSimulatedData
from cca_zoo.linear import CCA
from cca_zoo.model_selection import learning_curve, permutation_test_score

# %%
# Data
# ------
# We set the random seed for reproducibility
np.random.seed(42)

# We generate a linear dataset with 200 samples, 15 features per view,
# 3 latent dimensions and different correlations between the views
n = 200
p = 15
q = 15
latent_dims = 3
correlations = [0.9, 0.5, 0.1]

(X, Y) = LinearSimulatedData(
view_features=[p, q], latent_dims=latent_dims, correlation=correlations
).sample(n)

# %%
# Permutation Testing
# -------------------
# Permutation testing is a way to assess the significance of the model performance by comparing it with the performance on permuted data.
# We use a CCA model with 3 latent dimensions and a 2-fold cross-validation scheme.
model = CCA(latent_dimensions=latent_dims)
cv = KFold(2, shuffle=True, random_state=0)

# We use permutation_test_score to compute the score on the original data and on 100 permutations of the data.
score, perm_scores, pvalue = permutation_test_score(
model, (X, Y), cv=cv, n_permutations=100
)

# %%
# We plot the histogram of the permuted scores and the score on the original data for each dimension.
fig, ax = plt.subplots(latent_dims, figsize=[12, 8])
for k in range(latent_dims):
ax[k].hist(perm_scores[k])
ax[k].axvline(score[k], ls="--", color="r")
score_label = f"Score on original\ndata: {score[k]:.2f}\n(p-value: {pvalue[k]:.3f})"
ax[k].text(0.05, 0.8, score_label, fontsize=12, transform=ax[k].transAxes)
ax[k].set_xlabel("Correlation")
_ = ax[k].set_ylabel("Frequency")
ax[k].set_title(f"Dimension {k + 1}")
plt.tight_layout()
plt.show()


# %%
# Learning Curves
# -------------------

import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42) # We set the random seed for reproducibility
n = 250 # number of samples
p = 15 # features in view 1
q = 15 # features in view 2
latent_dims = 1 # latent dimensions
correlations = [0.9] # correlations between views

def plot_learning_curve(
estimator,
Expand Down Expand Up @@ -219,15 +171,39 @@ def plot_learning_curve(

return plt

# Data generation
(X, Y) = LinearSimulatedData(
view_features=[p, q], latent_dims=latent_dims, correlation=correlations
).sample(n)

# Permutation Testing
model = CCA(latent_dimensions=latent_dims)
cv = KFold(2, shuffle=True, random_state=0)
score, perm_scores, pvalue = permutation_test_score(
model, (X, Y), cv=cv, n_permutations=100
)

# Permutation Test Visualization
fig, ax = plt.subplots(latent_dims, figsize=[12, 8])
for k in range(latent_dims):
ax.hist(perm_scores)
ax.axvline(score, ls="--", color="r")
score_label = f"Score on original\ndata: {score:.2f}\n(p-value: {pvalue:.3f})"
ax.text(0.05, 0.8, score_label, fontsize=12, transform=ax.transAxes)
ax.set_xlabel("Correlation")
_ = ax.set_ylabel("Frequency")
ax.set_title(f"Dimension {k + 1}")
plt.tight_layout()
plt.show()


# Learning Curves
fig, axes = plt.subplots(3, 1, figsize=(10, 15))

title = "Learning Curves CCA"
# Cross validation with 50 iterations to get smoother mean test and train
# score curves, each time with 20% data randomly selected as a validation set.
cv = ShuffleSplit(n_splits=50, test_size=0.2, random_state=0)

model = CCA()

plot_learning_curve(model, title, (X, Y), axes=axes, ylim=(0.7, 1.01), cv=cv, n_jobs=4)

plt.show()

0 comments on commit 6e5d85a

Please sign in to comment.