Skip to content

Commit

Permalink
Validation set for Gradient Based Models woohoo
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Oct 19, 2023
1 parent a710bbd commit df02b9c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
20 changes: 10 additions & 10 deletions cca_zoo/linear/_gradient/_svd.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from cca_zoo.deep.objectives import CCA_SVDLoss, PLS_SVDLoss
from cca_zoo.linear._gradient._ey import CCA_EY


class CCA_SVD(CCA_EY):
objective = CCA_SVDLoss()


class PLS_SVD(CCA_EY):
objective = PLS_SVDLoss()
# from cca_zoo.deep.objectives import CCA_SVDLoss, PLS_SVDLoss
# from cca_zoo.linear._gradient._ey import CCA_EY
#
#
# class CCA_SVD(CCA_EY):
# objective = CCA_SVDLoss()
#
#
# class PLS_SVD(CCA_EY):
# objective = PLS_SVDLoss()
2 changes: 1 addition & 1 deletion test/test_probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def setup_data():
random_state=seed,
structure="identity",
)
X, Y = data.sample(100)
X, Y = data.sample(500)
X -= X.mean(axis=0)
Y -= Y.mean(axis=0)
return X, Y, data.joint_cov
Expand Down

0 comments on commit df02b9c

Please sign in to comment.