Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Oct 25, 2023
2 parents e14aac3 + 907eb06 commit c88618d
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 24 deletions.
24 changes: 11 additions & 13 deletions cca_zoo/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(

def _validate_data(self, views: Iterable[np.ndarray]):
if self.copy_data:
views= [
views = [
check_array(
view,
copy=True,
Expand All @@ -68,7 +68,7 @@ def _validate_data(self, views: Iterable[np.ndarray]):
for view in views
]
else:
views= [
views = [
check_array(
view,
copy=False,
Expand All @@ -88,7 +88,6 @@ def _validate_data(self, views: Iterable[np.ndarray]):
self.n_samples_ = views[0].shape[0]
return views


def _check_params(self):
"""
Checks the parameters of the model.
Expand Down Expand Up @@ -129,16 +128,15 @@ def transform(
"""
check_is_fitted(self)
views =[
check_array(
view,
copy=True,
accept_sparse=False,
accept_large_sparse=False,

)
for view in views
]
views = [
check_array(
view,
copy=True,
accept_sparse=False,
accept_large_sparse=False,
)
for view in views
]
transformed_views = []
for i, view in enumerate(views):
transformed_view = view @ self.weights_[i]
Expand Down
6 changes: 5 additions & 1 deletion cca_zoo/linear/_gradient/_ey.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def __init__(self, views, batch_size=None):

def __getitem__(self, index):
views = [view[index] for view in self.views]
independent_index = index if self.batch_size is None else self.random_state.randint(0, len(self))
independent_index = (
index
if self.batch_size is None
else self.random_state.randint(0, len(self))
)
independent_views = [view[independent_index] for view in self.views]
return {"views": views, "independent_views": independent_views}
4 changes: 3 additions & 1 deletion cca_zoo/linear/_grcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def _weights(self, eigvals, eigvecs, views, feature_groups=None, **kwargs):
# Loop through c and add group means to splits if c > 0
self.splits = [
n_features + n_groups if c > 0 else n_features
for n_features, n_groups, c in zip(self.n_features_in_, self.n_groups_, self.c)
for n_features, n_groups, c in zip(
self.n_features_in_, self.n_groups_, self.c
)
]

# Add zero at the beginning and compute cumulative sum of splits
Expand Down
2 changes: 1 addition & 1 deletion cca_zoo/linear/_iterative/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def fit(self, views: Iterable[np.ndarray], y=None, **kwargs):

def _fit(self, views: Iterable[np.ndarray]):
views = self._validate_data(views)
self.random_state= check_random_state(self.random_state)
self.random_state = check_random_state(self.random_state)
self._initialize(views)
self._check_params()
# Solve using alternating optimisation across the representations until convergence
Expand Down
2 changes: 1 addition & 1 deletion cca_zoo/linear/_iterative/_scca_span.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _update_weights(self, views: np.ndarray, i: int) -> None:

def _initialize_variables(self, views):
self.max_obj = [0, 0]
cov = cross_cov(views[0], views[1],rowvar=False)
cov = cross_cov(views[0], views[1], rowvar=False)
# Perform SVD on im and obtain individual matrices
P, D, Q = np.linalg.svd(cov, full_matrices=True)
self.P = P[:, : self.rank]
Expand Down
5 changes: 3 additions & 2 deletions cca_zoo/linear/_partialcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def fit(self, views: Iterable[np.ndarray], y=None, partials=None, **kwargs):
views, y=y, partials=partials, **kwargs
) # call the parent class fit method


def _process_data(self, views, partials=None, **kwargs):
if partials is None:
return super()._process_data(views, **kwargs)
Expand All @@ -63,7 +62,9 @@ def transform(self, views: Iterable[np.ndarray], partials=None, **kwargs):
if partials is None:
return super().transform(views, **kwargs)
else:
check_is_fitted(self) # check if the model has been fitted before transforming
check_is_fitted(
self
) # check if the model has been fitted before transforming
transformed_views = []
for i, (view) in enumerate(views):
transformed_view = (
Expand Down
13 changes: 9 additions & 4 deletions cca_zoo/linear/_tcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class TCCA(MCCA):
>>> model = TCCA()
>>> model.fit((X1,X2,X3)).score((X1,X2,X3))
"""

def fit(self, views: Iterable[np.ndarray], y=None, **kwargs):
# Validate the input data
views = self._validate_data(views)
Expand All @@ -60,12 +61,16 @@ def fit(self, views: Iterable[np.ndarray], y=None, **kwargs):
M = np.expand_dims(M, -1) @ el
M = np.mean(M, 0)
tl.set_backend("numpy")
M_parafac = parafac(M, self.latent_dimensions, verbose=False, random_state=self.random_state, init="random")
M_parafac = parafac(
M,
self.latent_dimensions,
verbose=False,
random_state=self.random_state,
init="random",
)
self.weights_ = [
cov_invsqrt @ fac
for i, (cov_invsqrt, fac) in enumerate(
zip(covs_invsqrt, M_parafac.factors)
)
for i, (cov_invsqrt, fac) in enumerate(zip(covs_invsqrt, M_parafac.factors))
]
return self

Expand Down
4 changes: 3 additions & 1 deletion test/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def fit(self, views, y=None):
def transform(self, views, y=None):
return views[0]

random_state=0

random_state = 0


@pytest.mark.parametrize(
"estimator",
Expand Down

0 comments on commit c88618d

Please sign in to comment.