Skip to content

Commit

Permalink
Validation set for Gradient Based Models woohoo
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Oct 24, 2023
1 parent 24acc28 commit 7436d60
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 36 deletions.
27 changes: 1 addition & 26 deletions cca_zoo/probabilistic/_cca.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
copy_data=True,
random_state: int = 0,
learning_rate=1e-1,
n_iter=40000,
n_iter=20000,
num_samples=5000,
num_warmup=5000,
):
Expand Down Expand Up @@ -214,31 +214,6 @@ def _guide(self, views):
with numpyro.plate("n", self.n_samples_):
z = numpyro.sample("z", dist.MultivariateNormal(z_loc, jnp.diag(z_scale)))

def transform(self, views: Iterable[np.ndarray], y=None, return_std=False):
    """
    Run stochastic variational inference on ``views`` and return the ``"z"``
    entry of the resulting parameters.

    A new SVI object (Adam optimiser, Trace_ELBO loss) is built on every
    call and run for ``self.n_iter`` steps, so each call performs a full
    optimisation rather than reusing a previous fit.

    Parameters
    ----------
    views : Iterable[np.ndarray]
        A list or tuple of numpy arrays holding different representations
        of the same samples. Each array must have the same number of rows.
    y : Any, optional
        Not used by this method.
    return_std : bool, optional
        Not used by this method.

    Returns
    -------
    representations : np.ndarray
        ``params["z"]`` from the SVI run, converted to a NumPy array.
    """
    optimiser = numpyro.optim.Adam(self.learning_rate)
    objective = numpyro.infer.Trace_ELBO()
    inference = SVI(self._model, self._guide, optimiser, loss=objective)
    fitted = inference.run(self.rng_key, self.n_iter, views)
    # NOTE(review): confirm that "z" is actually a key of the optimised
    # params; the guide appears to parameterise z through z_loc / z_scale.
    return np.array(fitted.params["z"])

def render(self, views):
# check if graphviz is installed
try:
Expand Down
2 changes: 1 addition & 1 deletion test/test_explained_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_transformed_covariance_ratio(toy_model, synthetic_views):
pls_cov_ratios = pls.explained_covariance_ratio(synthetic_views)
# sum of these should be 1 within a small tolerance
assert np.isclose(
np.sum(pls_cov_ratios), 1, atol=1e-2
np.sum(pls_cov_ratios), 1, atol=2e-2
), "Expected sum of ratios to be 1"

cov_ratios = toy_model.explained_covariance_ratio(synthetic_views)
Expand Down
26 changes: 19 additions & 7 deletions test/test_probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import pytest

from cca_zoo.datasets import LatentVariableData
from cca_zoo.linear import CCA
from cca_zoo.linear import CCA, PLS
from cca_zoo.probabilistic import ProbabilisticCCA
from cca_zoo.probabilistic._pls import ProbabilisticPLS


@pytest.fixture
Expand Down Expand Up @@ -32,13 +33,24 @@ def test_cca_vs_probabilisticCCA(setup_data):

# Assert: Calculate correlation coefficient and ensure it's greater than 0.8
z = cca.transform([X, Y])[0]
z_p = np.array(pcca.transform([X, Y]))
# correlation between cca and pcca
correlation_matrix = np.abs(np.corrcoef(z.reshape(-1), z_p.reshape(-1)))
correlation_matrix = np.abs(np.corrcoef(z.reshape(-1), pcca.params["z_loc"].reshape(-1)))
correlation = correlation_matrix[0, 1]

assert correlation > 0.8

#
# assert (
# correlation > 0.9
# ), f"Expected correlation greater than 0.95, got {correlation}"
def test_pls_vs_probabilisticPLS(setup_data):
    """Check that the ProbabilisticPLS latent means track the PLS scores.

    Fits a one-dimensional PLS and a one-dimensional ProbabilisticPLS on the
    same two views, then requires the absolute correlation between the PLS
    scores of the first view and the fitted ``z_loc`` parameter to exceed 0.8.
    """
    # The fixture's third element is not needed here.
    X, Y, _ = setup_data

    # Models and fit
    pls = PLS(latent_dimensions=1)
    ppls = ProbabilisticPLS(latent_dimensions=1, random_state=10)
    pls.fit([X, Y])
    ppls.fit([X, Y])

    # Scores of the first view under the deterministic model
    z = pls.transform([X, Y])[0]

    # Correlation between the PLS scores and the ProbabilisticPLS latent
    # means; the absolute value is taken because the sign of a latent
    # direction is arbitrary.
    correlation_matrix = np.abs(
        np.corrcoef(z.reshape(-1), ppls.params["z_loc"].reshape(-1))
    )
    correlation = correlation_matrix[0, 1]

    assert (
        correlation > 0.8
    ), f"Expected correlation greater than 0.8, got {correlation}"
4 changes: 2 additions & 2 deletions test/test_sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# Fixtures
@pytest.fixture
def simulated_data():
data_generator = JointData(view_features=[4, 6], latent_dims=5, correlation=0.8)
data_generator = JointData(view_features=[4, 6], latent_dims=3, correlation=0.8)
X, Y = data_generator.sample(50)
return X, Y

Expand All @@ -32,7 +32,7 @@ def test_sequential_model_fits_and_identifies_effects(
X, Y = simulated_data
sequential_model = SequentialModel(
grid_search_estimator,
latent_dimensions=10,
latent_dimensions=4,
permutation_test=True,
p_threshold=0.05,
)
Expand Down

0 comments on commit 7436d60

Please sign in to comment.