Skip to content

Commit

Permalink
Validation set for Gradient Based Models woohoo
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Oct 26, 2023
1 parent 6765b62 commit 44c92f6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 13 deletions.
9 changes: 3 additions & 6 deletions cca_zoo/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ class BaseModel(BaseEstimator, MultiOutputMixin, TransformerMixin):
"""

_loadings = None
_canonical_loadings = None
weights_=None

def __init__(
self,
Expand Down Expand Up @@ -276,12 +275,10 @@ def loadings_(self) -> List[np.ndarray]:
Loadings for each view.
"""
check_is_fitted(self, attributes=["weights_"])
if self._loadings is None:
# Compute loadings_ only if they haven't been computed yet
self._loadings = [
loadings=[
weights / np.linalg.norm(weights, axis=0) for weights in self.weights_
]
return self._loadings
return loadings

def explained_variance(self, views: Iterable[np.ndarray]) -> List[np.ndarray]:
"""
Expand Down
13 changes: 6 additions & 7 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,16 @@
intersphinx_mapping = {
"numpy": ("https://docs.scipy.org/doc/numpy", None),
"python": ("https://docs.python.org/3", None),
"sklearn": ("http://scikit-learn.org/dev", None),
# "torch": ("https://pytorch.org/docs/master", None),
# "jax": ("https://jax.readthedocs.io/en/latest/", None),
# "numpyro": ("https://numpyro.readthedocs.io/en/latest/", None),
# "jaxlib": ("https://jax.readthedocs.io/en/latest/", None),
# "lightning": ("https://pytorch-lightning.readthedocs.io/en/latest/", None),
"sklearn": ("https://scikit-learn.org/dev", None),
"torch": ("https://pytorch.org/docs/master", None),
"jax": ("https://jax.readthedocs.io/en/latest/", None),
"numpyro": ("https://numpyro.readthedocs.io/en/latest/", None),
"jaxlib": ("https://jax.readthedocs.io/en/latest/", None),
"lightning": ("https://pytorch-lightning.readthedocs.io/en/latest/", None),
}

autodoc_default_options = {
"members": True,
"inherited-members": True,
"show-inheritance": True,
"member-order": "bysource",
}
Expand Down

0 comments on commit 44c92f6

Please sign in to comment.