Skip to content

Commit

Permalink
Big Update
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Nov 13, 2023
1 parent a68e4a2 commit 1894e3f
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/test_deepmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
objectives,
)
from cca_zoo.deep.data import NumpyDataset, get_dataloaders, check_dataset
from cca_zoo.linear import CCA, GCCA, MCCA
from cca_zoo.linear import CCA, GCCA, MCCA, TCCA

seed_everything(0)
rng = check_random_state(0)
Expand Down Expand Up @@ -121,7 +121,7 @@ def test_DTCCA_methods():
max_epochs = 20
# check that DTCCA is equivalent to _CCALoss for 2 representations with linear encoders
latent_dimensions = 2
cca = CCA(latent_dimensions=latent_dimensions)
tcca = TCCA(latent_dimensions=latent_dimensions)
encoder_1 = architectures.LinearEncoder(
latent_dimensions=latent_dimensions, feature_size=feature_size[0]
)
Expand All @@ -136,10 +136,10 @@ def test_DTCCA_methods():
z = dtcca.transform(train_loader)
assert (
np.testing.assert_array_almost_equal(
cca.fit((X[train_ids], Y[train_ids]))
tcca.fit((X[train_ids], Y[train_ids]))
.score((X[train_ids], Y[train_ids]))
.sum(),
cca.fit((z)).score((z)).sum(),
tcca.fit((z)).score((z)).sum(),
decimal=1,
)
is None
Expand Down

0 comments on commit 1894e3f

Please sign in to comment.