From 907eb061c809aad574bbeffe9a1cfe7a8695f2b9 Mon Sep 17 00:00:00 2001 From: jameschapman19 Date: Wed, 25 Oct 2023 13:44:48 +0000 Subject: [PATCH] Format code with black --- cca_zoo/_base.py | 24 +++++++++++------------- cca_zoo/linear/_gradient/_ey.py | 6 +++++- cca_zoo/linear/_grcca.py | 4 +++- cca_zoo/linear/_iterative/_base.py | 2 +- cca_zoo/linear/_iterative/_scca_span.py | 2 +- cca_zoo/linear/_partialcca.py | 5 +++-- cca_zoo/linear/_tcca.py | 13 +++++++++---- test/test_common.py | 4 +++- 8 files changed, 36 insertions(+), 24 deletions(-) diff --git a/cca_zoo/_base.py b/cca_zoo/_base.py index 54c9f86c..a1454e64 100644 --- a/cca_zoo/_base.py +++ b/cca_zoo/_base.py @@ -56,7 +56,7 @@ def __init__( def _validate_data(self, views: Iterable[np.ndarray]): if self.copy_data: - views= [ + views = [ check_array( view, copy=True, @@ -68,7 +68,7 @@ def _validate_data(self, views: Iterable[np.ndarray]): for view in views ] else: - views= [ + views = [ check_array( view, copy=False, @@ -88,7 +88,6 @@ def _validate_data(self, views: Iterable[np.ndarray]): self.n_samples_ = views[0].shape[0] return views - def _check_params(self): """ Checks the parameters of the model. @@ -129,16 +128,15 @@ def transform( """ check_is_fitted(self) - views =[ - check_array( - view, - copy=True, - accept_sparse=False, - accept_large_sparse=False, - - ) - for view in views - ] + views = [ + check_array( + view, + copy=True, + accept_sparse=False, + accept_large_sparse=False, + ) + for view in views + ] transformed_views = [] for i, view in enumerate(views): transformed_view = view @ self.weights_[i] diff --git a/cca_zoo/linear/_gradient/_ey.py b/cca_zoo/linear/_gradient/_ey.py index 599e0089..40662bd8 100644 --- a/cca_zoo/linear/_gradient/_ey.py +++ b/cca_zoo/linear/_gradient/_ey.py @@ -106,6 +106,10 @@ def __init__(self, views, batch_size=None): def __getitem__(self, index): views = [view[index] for view in self.views] - independent_index = index if self.batch_size is None else self.random_state.randint(0, len(self)) + independent_index = ( + index + if self.batch_size is None + else self.random_state.randint(0, len(self)) + ) independent_views = [view[independent_index] for view in self.views] return {"views": views, "independent_views": independent_views} diff --git a/cca_zoo/linear/_grcca.py b/cca_zoo/linear/_grcca.py index 08634c5b..96cacd58 100644 --- a/cca_zoo/linear/_grcca.py +++ b/cca_zoo/linear/_grcca.py @@ -62,7 +62,9 @@ def _weights(self, eigvals, eigvecs, views, feature_groups=None, **kwargs): # Loop through c and add group means to splits if c > 0 self.splits = [ n_features + n_groups if c > 0 else n_features - for n_features, n_groups, c in zip(self.n_features_in_, self.n_groups_, self.c) + for n_features, n_groups, c in zip( + self.n_features_in_, self.n_groups_, self.c + ) ] # Add zero at the beginning and compute cumulative sum of splits diff --git a/cca_zoo/linear/_iterative/_base.py b/cca_zoo/linear/_iterative/_base.py index b4b89c2d..8352cd8e 100644 --- a/cca_zoo/linear/_iterative/_base.py +++ b/cca_zoo/linear/_iterative/_base.py @@ -51,7 +51,7 @@ def fit(self, views: Iterable[np.ndarray], y=None, **kwargs): def _fit(self, views: Iterable[np.ndarray]): views = self._validate_data(views) - self.random_state= check_random_state(self.random_state) + self.random_state = check_random_state(self.random_state) self._initialize(views) self._check_params() # Solve using alternating optimisation across the representations until convergence diff --git a/cca_zoo/linear/_iterative/_scca_span.py b/cca_zoo/linear/_iterative/_scca_span.py index b67bfdbf..13269d48 100644 --- a/cca_zoo/linear/_iterative/_scca_span.py +++ b/cca_zoo/linear/_iterative/_scca_span.py @@ -99,7 +99,7 @@ def _update_weights(self, views: np.ndarray, i: int) -> None: def _initialize_variables(self, views): self.max_obj = [0, 0] - cov = cross_cov(views[0], views[1],rowvar=False) + cov = cross_cov(views[0], views[1], rowvar=False) # Perform SVD on im and obtain individual matrices P, D, Q = np.linalg.svd(cov, full_matrices=True) self.P = P[:, : self.rank] diff --git a/cca_zoo/linear/_partialcca.py b/cca_zoo/linear/_partialcca.py index 5bfe82e9..38d62ec7 100644 --- a/cca_zoo/linear/_partialcca.py +++ b/cca_zoo/linear/_partialcca.py @@ -42,7 +42,6 @@ def fit(self, views: Iterable[np.ndarray], y=None, partials=None, **kwargs): views, y=y, partials=partials, **kwargs ) # call the parent class fit method - def _process_data(self, views, partials=None, **kwargs): if partials is None: return super()._process_data(views, **kwargs) @@ -63,7 +62,9 @@ def transform(self, views: Iterable[np.ndarray], partials=None, **kwargs): if partials is None: return super().transform(views, **kwargs) else: - check_is_fitted(self) # check if the model has been fitted before transforming + check_is_fitted( + self + ) # check if the model has been fitted before transforming transformed_views = [] for i, (view) in enumerate(views): transformed_view = ( diff --git a/cca_zoo/linear/_tcca.py b/cca_zoo/linear/_tcca.py index 5fd018d5..02f0a8d2 100644 --- a/cca_zoo/linear/_tcca.py +++ b/cca_zoo/linear/_tcca.py @@ -38,6 +38,7 @@ class TCCA(MCCA): >>> model = TCCA() >>> model.fit((X1,X2,X3)).score((X1,X2,X3)) """ + def fit(self, views: Iterable[np.ndarray], y=None, **kwargs): # Validate the input data views = self._validate_data(views) @@ -60,12 +61,16 @@ def fit(self, views: Iterable[np.ndarray], y=None, **kwargs): M = np.expand_dims(M, -1) @ el M = np.mean(M, 0) tl.set_backend("numpy") - M_parafac = parafac(M, self.latent_dimensions, verbose=False, random_state=self.random_state, init="random") + M_parafac = parafac( + M, + self.latent_dimensions, + verbose=False, + random_state=self.random_state, + init="random", + ) self.weights_ = [ cov_invsqrt @ fac - for i, (cov_invsqrt, fac) in enumerate( - zip(covs_invsqrt, M_parafac.factors) - ) + for i, (cov_invsqrt, fac) in enumerate(zip(covs_invsqrt, M_parafac.factors)) ] return self diff --git a/test/test_common.py b/test/test_common.py index 96de61a5..a425bde6 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -54,7 +54,9 @@ def fit(self, views, y=None): def transform(self, views, y=None): return views[0] -random_state=0 + +random_state = 0 + @pytest.mark.parametrize( "estimator",