From 7b357a8b0ed673b900031bdd4a7f395ad2d85259 Mon Sep 17 00:00:00 2001 From: Niclas Rieger Date: Tue, 3 Sep 2024 20:17:44 +0200 Subject: [PATCH] feat: add POP analysis --- .../_autosummary/xeofs.single.POP.rst | 44 ++ docs/api_reference/single_set_analysis.rst | 1 + docs/auto_examples/auto_examples_jupyter.zip | Bin 105206 -> 105206 bytes docs/auto_examples/auto_examples_python.zip | Bin 62435 -> 62435 bytes tests/models/single/test_pop.py | 219 ++++++++ xeofs/cross/base_model_cross_set.py | 1 + xeofs/linalg/__init__.py | 3 + xeofs/linalg/_numpy/_svd.py | 14 +- xeofs/linalg/svd.py | 3 + xeofs/linalg/utils.py | 6 + xeofs/single/__init__.py | 2 + xeofs/single/base_model_single_set.py | 1 + xeofs/single/pop.py | 497 ++++++++++++++++++ 13 files changed, 790 insertions(+), 1 deletion(-) create mode 100644 docs/api_reference/_autosummary/xeofs.single.POP.rst create mode 100644 tests/models/single/test_pop.py create mode 100644 xeofs/linalg/utils.py create mode 100644 xeofs/single/pop.py diff --git a/docs/api_reference/_autosummary/xeofs.single.POP.rst b/docs/api_reference/_autosummary/xeofs.single.POP.rst new file mode 100644 index 00000000..8d1e6b8b --- /dev/null +++ b/docs/api_reference/_autosummary/xeofs.single.POP.rst @@ -0,0 +1,44 @@ +POP +=== + +.. currentmodule:: xeofs.single + +.. autoclass:: POP + :members: + :inherited-members: + + + .. automethod:: __init__ + + + .. rubric:: Methods + + .. autosummary:: + + ~POP.__init__ + ~POP.components + ~POP.components_amplitude + ~POP.components_phase + ~POP.compute + ~POP.damping_times + ~POP.deserialize + ~POP.eigenvalues + ~POP.fit + ~POP.fit_transform + ~POP.get_params + ~POP.get_serialization_attrs + ~POP.inverse_transform + ~POP.load + ~POP.periods + ~POP.save + ~POP.scores + ~POP.scores_amplitude + ~POP.scores_phase + ~POP.serialize + ~POP.transform + + + + + + \ No newline at end of file diff --git a/docs/api_reference/single_set_analysis.rst b/docs/api_reference/single_set_analysis.rst index 6a975b8a..744f0f20 100644 --- a/docs/api_reference/single_set_analysis.rst +++ b/docs/api_reference/single_set_analysis.rst @@ -13,6 +13,7 @@ Methods that investigate relationships or patterns between variables within a si ~xeofs.single.ComplexEOF ~xeofs.single.HilbertEOF ~xeofs.single.ExtendedEOF + ~xeofs.single.POP ~xeofs.single.OPA ~xeofs.single.GWPCA ~xeofs.single.SparsePCA diff --git a/docs/auto_examples/auto_examples_jupyter.zip b/docs/auto_examples/auto_examples_jupyter.zip index 186ccd5aee0f3a2528da5c57cb29ec78c2089b0a..fbfeb6444c953a5bceb153ac6b59fbb558410a0f 100644 GIT binary patch delta 378 zcmeyimF?SBHr@blW)=|!5SUh^ypi{#05g!@%qUnW3TBk5l|UG8tWWcTMRs{MLqr_n zr}2PAZe~vs1T*q#pK*d2;q5;(!Hl4dpCM8nhYRGvBD){yh=CceeoPbxF}BwUFhYd3 zuT^7&iaFXbo`A^mguz(Xa~UrPgLSNGWV{4sO|PHI=m`>7U#K4SKKPHMZ%hV}tuM=Pd z3V~Q_)fk~-j&_VEAhJARFxK^4#tXt=b*mZ~FA0MTnO;AY(G$!#J(bY|%n+N#=mcg& zPh)fjGZs#Rs(Cw&(GM(QH=WT7%;=uZ=niIFoeotjJp(F|Is>X};|xXzu$nJ3fK~vV pIUOh>Jw0kBqbiu!G!rT;FbgRB0wxUPIe_ICf`s#CF&e;lp#Y^hcLo3e diff --git a/docs/auto_examples/auto_examples_python.zip b/docs/auto_examples/auto_examples_python.zip index 232d002a881ef71513e630766d73fb8c71a7ee08..601f75f1255d4bd5e840faa2e5b3116b3720baf1 100644 GIT binary patch delta 350 zcmaF-ocZx{X5IjAW)=|!5SUh^ypcDRlNm^FF69j51~X2IOyB`CViftBk=+>eN0$__-e;>g_&BPrmq339P0L$QJyh1!miQvW1F4>^lbKvwen$ RPy8%BIq|bDSk1)GK>*SPiXZ?0 delta 350 zcmaF-ocZx{X5IjAW)=|!5Rk4@+Q^&A$qb}7mvRPjgBd49Ch&k6F^YWbU`DP+D<7Eg z)oc+Pm@&)tBo~;G=COzc%;5H|mH{(5GoM1F3QJ@~z#^J#Dk8k=79UvT!wD$k?DY(8u!zOuEU?2SFMa0*GH>#`cdlUC?7cmhu6XYNrgy)G 
z$g_O#2J?eHxP$3AA6&rn;|~yZY9Ar|+>a3TcAumtU;L;9R#OLL3x3i9Sv=Y9lP#F8 c`{W9yk3sotpCRHCKTA(e{HzO>oA@~h0K%Dhi~s-t diff --git a/tests/models/single/test_pop.py b/tests/models/single/test_pop.py new file mode 100644 index 00000000..b42e9748 --- /dev/null +++ b/tests/models/single/test_pop.py @@ -0,0 +1,219 @@ +import numpy as np +import pytest +import xarray as xr + +from xeofs.single import POP + + +def test_init(): + """Tests the initialization of the POP class""" + pop = POP(n_modes=5, standardize=True, use_coslat=True) + + # Assert preprocessor has been initialized + assert hasattr(pop, "_params") + assert hasattr(pop, "preprocessor") + assert hasattr(pop, "whitener") + + +def test_fit(mock_data_array): + pop = POP() + pop.fit(mock_data_array, "time") + + +def test_eigenvalues(mock_data_array): + pop = POP() + pop.fit(mock_data_array, "time") + + eigenvalues = pop.eigenvalues() + assert isinstance(eigenvalues, xr.DataArray) + + +def test_damping_times(mock_data_array): + pop = POP() + pop.fit(mock_data_array, "time") + + times = pop.damping_times() + assert isinstance(times, xr.DataArray) + + +def test_periods(mock_data_array): + pop = POP() + pop.fit(mock_data_array, "time") + + periods = pop.periods() + assert isinstance(periods, xr.DataArray) + + +def test_components(mock_data_array): + """Tests the components method of the POP class""" + sample_dim = ("time",) + pop = POP() + pop.fit(mock_data_array, sample_dim) + + # Test components method + components = pop.components() + feature_dims = tuple(set(mock_data_array.dims) - set(sample_dim)) + assert isinstance(components, xr.DataArray), "Components is not a DataArray" + assert set(components.dims) == set( + ("mode",) + feature_dims + ), "Components does not have the right feature dimensions" + + +def test_scores(mock_data_array): + """Tests the scores method of the POP class""" + sample_dim = ("time",) + pop = POP() + pop.fit(mock_data_array, sample_dim) + + # Test scores method + scores = pop.scores() + assert isinstance(scores, xr.DataArray), "Scores is not a DataArray" + assert set(scores.dims) == set( + (sample_dim + ("mode",)) + ), "Scores does not have the right dimensions" + + +def test_transform(mock_data_array): + """Test projecting new unseen data onto the POPs""" + dim = ("time",) + data = mock_data_array.isel({dim[0]: slice(1, None)}) + new_data = mock_data_array.isel({dim[0]: slice(0, 1)}) + + # Create a xarray DataArray with random data + model = POP(n_modes=2, solver="full") + model.fit(data, dim) + scores = model.scores() + + # Project data onto the components + projections = model.transform(data) + + # Check that the projection has the right dimensions + assert projections.dims == scores.dims, "Projection has wrong dimensions" # type: ignore + + # Check that the projection has the right data type + assert isinstance(projections, xr.DataArray), "Projection is not a DataArray" + + # Check that the projection has the right name + assert projections.name == "scores", "Projection has wrong name: {}".format( + projections.name + ) + + # Check that the projection's data is the same as the scores + np.testing.assert_allclose( + scores.sel(mode=slice(1, 3)), projections.sel(mode=slice(1, 3)), rtol=1e-3 + ) + + # Project unseen data onto the components + new_projections = model.transform(new_data) + + # Check that the projection has the right dimensions + assert new_projections.dims == scores.dims, "Projection has wrong dimensions" # type: ignore + + # Check that the projection has the right data type + assert 
isinstance(new_projections, xr.DataArray), "Projection is not a DataArray" + + # Check that the projection has the right name + assert new_projections.name == "scores", "Projection has wrong name: {}".format( + new_projections.name + ) + + # Ensure that the new projections are not NaNs + assert np.all(new_projections.notnull().values), "New projections contain NaNs" + + +def test_inverse_transform(mock_data_array): + """Test inverse_transform method in POP class.""" + + dim = ("time",) + # instantiate the POP class with necessary parameters + pop = POP(n_modes=20, standardize=True) + + # fit the POP model + pop.fit(mock_data_array, dim=dim) + scores = pop.scores() + + # Test with single mode + scores_selection = scores.sel(mode=1) + X_rec_1 = pop.inverse_transform(scores_selection) + assert isinstance(X_rec_1, xr.DataArray) + + # Test with single mode as list + scores_selection = scores.sel(mode=[1]) + X_rec_1_list = pop.inverse_transform(scores_selection) + assert isinstance(X_rec_1_list, xr.DataArray) + + # Single mode and list should be equal + xr.testing.assert_allclose(X_rec_1, X_rec_1_list) + + # Test with all modes + X_rec = pop.inverse_transform(scores) + assert isinstance(X_rec, xr.DataArray) + + # Check that the reconstructed data has the same dimensions as the original data + assert set(X_rec.dims) == set(mock_data_array.dims) + + +@pytest.mark.parametrize("engine", ["zarr"]) +def test_save_load(mock_data_array, tmp_path, engine): + """Test save/load methods in POP class, ensuring that we can + roundtrip the model and get the same results when transforming + data.""" + # NOTE: netcdf4 does not support complex data types, so we use only zarr here + dim = "time" + original = POP() + original.fit(mock_data_array, dim) + + # Save the POP model + original.save(tmp_path / "pop", engine=engine) + + # Check that the POP model has been saved + assert (tmp_path / "pop").exists() + + # Recreate the model from saved file + loaded = POP.load(tmp_path / "pop", engine=engine) + + # Check that the params and DataContainer objects match + assert original.get_params() == loaded.get_params() + assert all([key in loaded.data for key in original.data]) + for key in original.data: + if original.data._allow_compute[key]: + assert loaded.data[key].equals(original.data[key]) + else: + # but ensure that input data is not saved by default + assert loaded.data[key].size <= 1 + assert loaded.data[key].attrs["placeholder"] is True + + # Test that the recreated model can be used to transform new data + assert np.allclose( + original.transform(mock_data_array), loaded.transform(mock_data_array) + ) + + # The loaded model should also be able to inverse_transform new data + assert np.allclose( + original.inverse_transform(original.scores()), + loaded.inverse_transform(loaded.scores()), + ) + + +def test_serialize_deserialize_dataarray(mock_data_array): + """Test roundtrip serialization when the model is fit on a DataArray.""" + dim = "time" + model = POP() + model.fit(mock_data_array, dim) + dt = model.serialize() + rebuilt_model = POP.deserialize(dt) + assert np.allclose( + model.transform(mock_data_array), rebuilt_model.transform(mock_data_array) + ) + + +def test_serialize_deserialize_dataset(mock_dataset): + """Test roundtrip serialization when the model is fit on a Dataset.""" + dim = "time" + model = POP() + model.fit(mock_dataset, dim) + dt = model.serialize() + rebuilt_model = POP.deserialize(dt) + assert np.allclose( + model.transform(mock_dataset), rebuilt_model.transform(mock_dataset) + ) diff --git 
a/xeofs/cross/base_model_cross_set.py b/xeofs/cross/base_model_cross_set.py index f3cf5472..ee801947 100644 --- a/xeofs/cross/base_model_cross_set.py +++ b/xeofs/cross/base_model_cross_set.py @@ -304,6 +304,7 @@ def fit( if self.get_params()["compute"]: self.data.compute() + self._post_compute() return self diff --git a/xeofs/linalg/__init__.py b/xeofs/linalg/__init__.py index e69de29b..c2cfd896 100644 --- a/xeofs/linalg/__init__.py +++ b/xeofs/linalg/__init__.py @@ -0,0 +1,3 @@ +from .utils import total_variance + +__all__ = ["total_variance"] diff --git a/xeofs/linalg/_numpy/_svd.py b/xeofs/linalg/_numpy/_svd.py index ce796add..013997de 100644 --- a/xeofs/linalg/_numpy/_svd.py +++ b/xeofs/linalg/_numpy/_svd.py @@ -66,6 +66,7 @@ def __init__( solver: str = "auto", random_state: np.random.Generator | int | None = None, solver_kwargs: dict = {}, + is_complex: bool | str = "auto", ): sanity_check_n_modes(n_modes) self.is_based_on_variance = True if isinstance(n_modes, float) else False @@ -83,6 +84,7 @@ def __init__( self.solver = solver self.random_state = random_state self.solver_kwargs = solver_kwargs + self.is_complex = is_complex def _get_n_modes_precompute(self, rank: int) -> int: if self.is_based_on_variance: @@ -122,7 +124,17 @@ def fit_transform(self, X): # Source: https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html use_dask = True if isinstance(X, DaskArray) else False - use_complex = True if np.iscomplexobj(X) else False + + match self.is_complex: + case bool(): + use_complex = self.is_complex + case "auto": + use_complex = True if np.iscomplexobj(X) else False + case _: + raise ValueError( + f"Unrecognized value for is_complex '{self.is_complex}'. " + "Valid options are True, False, and 'auto'." + ) is_small_data = max(X.shape) < 500 diff --git a/xeofs/linalg/svd.py b/xeofs/linalg/svd.py index 137331a4..563ad1e8 100644 --- a/xeofs/linalg/svd.py +++ b/xeofs/linalg/svd.py @@ -10,6 +10,7 @@ class SVD: def __init__( self, n_modes: int | float | str, + is_complex: bool | str = "auto", init_rank_reduction: float = 0.3, flip_signs: bool = True, solver: str = "auto", @@ -20,6 +21,7 @@ def __init__( feature_name: str = "feature", ): self.n_modes = n_modes + self.is_complex = is_complex self.init_rank_reduction = init_rank_reduction self.flip_signs = flip_signs self.solver = solver @@ -54,6 +56,7 @@ def fit_transform(self, X: DataArray) -> tuple[DataArray, DataArray, DataArray]: flip_signs=self.flip_signs, solver=self.solver, random_state=self.random_state, + is_complex=self.is_complex, **self.solver_kwargs, ) U, s, V = xr.apply_ufunc( diff --git a/xeofs/linalg/utils.py b/xeofs/linalg/utils.py new file mode 100644 index 00000000..418dcf69 --- /dev/null +++ b/xeofs/linalg/utils.py @@ -0,0 +1,6 @@ +from ..utils.data_types import DataArray + + +def total_variance(X: DataArray, dim: str) -> DataArray: + """Compute the total variance of the centered data.""" + return (X * X.conj()).sum() / (X[dim].size - 1) diff --git a/xeofs/single/__init__.py b/xeofs/single/__init__.py index daf30d80..6acd8c3c 100644 --- a/xeofs/single/__init__.py +++ b/xeofs/single/__init__.py @@ -5,12 +5,14 @@ from .eof_rotator import ComplexEOFRotator, EOFRotator, HilbertEOFRotator from .gwpca import GWPCA from .opa import OPA +from .pop import POP from .sparse_pca import SparsePCA __all__ = [ "EOF", "ExtendedEOF", "SparsePCA", + "POP", "OPA", "GWPCA", "ComplexEOF", diff --git a/xeofs/single/base_model_single_set.py b/xeofs/single/base_model_single_set.py index b6d0ec7b..c4b5fdc9 100644 --- 
a/xeofs/single/base_model_single_set.py
+++ b/xeofs/single/base_model_single_set.py
@@ -156,6 +156,7 @@ def fit(
 
         if self._params["compute"]:
             self.data.compute()
+            self._post_compute()
 
         return self
 
diff --git a/xeofs/single/pop.py b/xeofs/single/pop.py
new file mode 100644
index 00000000..c783a964
--- /dev/null
+++ b/xeofs/single/pop.py
@@ -0,0 +1,497 @@
+import warnings
+
+import numpy as np
+import xarray as xr
+from typing_extensions import Self
+
+from ..linalg import total_variance
+from ..preprocessing import Whitener
+from ..utils.data_types import DataArray, DataObject
+from ..utils.xarray_utils import argsort_dask
+from .base_model_single_set import BaseModelSingleSet
+
+
+class POP(BaseModelSingleSet):
+    """Principal Oscillation Pattern (POP) analysis.
+
+    POP analysis [1]_ [2]_ is a linear multivariate technique used to identify
+    and describe dominant oscillatory modes in a dynamical system. POP analysis
+    involves computing the eigenvalues and eigenvectors of the `feedback matrix`
+    defined as
+
+    .. math::
+        A = C_1 C_0^{-1}
+
+    where :math:`C_0` is the covariance matrix and :math:`C_1` is the lag-1
+    covariance matrix of the input data. The eigenvectors of the feedback matrix
+    are the POPs and the eigenvalues are related to the damping times and
+    periods of the oscillatory modes.
+
+    Parameters
+    ----------
+    n_modes: int, default=2
+        Number of modes to calculate.
+    center: bool, default=True
+        Whether to center the input data.
+    standardize: bool, default=False
+        Whether to standardize the input data.
+    use_coslat: bool, default=False
+        Whether to use cosine of latitude for scaling.
+    use_pca : bool, default=True
+        If True, perform PCA to reduce the dimensionality of the data.
+    n_pca_modes : int | float | str, default=0.999
+        If int, specifies the number of modes to retain. If float, specifies the
+        fraction of variance in the (whitened) data that should be explained by
+        the retained modes. If "all", all modes are retained.
+    pca_init_rank_reduction : float, default=0.3
+        Only relevant when `use_pca=True` and `n_pca_modes` is a float, in which
+        case it denotes the fraction of the initial rank to reduce the data to
+        via PCA as a first guess before truncating the solution to the desired
+        fraction of explained variance. This allows for faster computation of
+        PCA via randomized SVD and avoids the need to compute the full SVD.
+    sample_name: str, default="sample"
+        Name of the sample dimension.
+    feature_name: str, default="feature"
+        Name of the feature dimension.
+    check_nans : bool, default=True
+        If True, remove full-dimensional NaN features from the data, check to
+        ensure that NaN features match the original fit data during transform,
+        and check for isolated NaNs. Note: this forces eager computation of dask
+        arrays. If False, skip all NaN checks. In this case, NaNs should be
+        explicitly removed or filled prior to fitting, or SVD will fail.
+    compute : bool, default=True
+        Whether to compute elements of the model eagerly, or to defer computation.
+        If True, four pieces of the fit will be computed sequentially: 1) the
+        preprocessor scaler, 2) optional NaN checks, 3) SVD decomposition, 4) scores
+        and components.
+    random_state : int, optional
+        Seed for the random number generator.
+    solver: {"auto", "full", "randomized"}, default="auto"
+        Solver to use for the SVD computation.
+    solver_kwargs: dict, default={}
+        Additional keyword arguments to be passed to the SVD solver.
+
+    References
+    ----------
+    .. [1] Hasselmann, K.
       PIPs and POPs: The reduction of complex dynamical systems using principal
+       interaction and oscillation patterns. J. Geophys. Res. 93, 11015–11021 (1988).
+    .. [2] von Storch, H., G. Bürger, R. Schnur, and J. von Storch, 1995:
+       Principal Oscillation Patterns: A Review. J. Climate, 8, 377–400,
+       https://doi.org/10.1175/1520-0442(1995)008<0377:POPAR>2.0.CO;2.
+
+
+    Examples
+    --------
+
+    Perform POP analysis in PC space spanned by the first 10 modes:
+
+    >>> pop = xe.single.POP(n_modes="all", use_pca=True, n_pca_modes=10)
+    >>> pop.fit(X, "time")
+
+    Get the POPs and associated time coefficients:
+
+    >>> patterns = pop.components()
+    >>> scores = pop.scores()
+
+    Reconstruct the original data using a conjugate pair of POPs:
+
+    >>> pop_pairs = scores.sel(mode=[1, 2])
+    >>> X_rec = pop.inverse_transform(pop_pairs)
+
+    """
+
+    def __init__(
+        self,
+        n_modes: int = 2,
+        center: bool = True,
+        standardize: bool = False,
+        use_coslat: bool = False,
+        use_pca: bool = True,
+        n_pca_modes: float | int = 0.999,
+        pca_init_rank_reduction: float = 0.3,
+        check_nans=True,
+        sample_name: str = "sample",
+        feature_name: str = "feature",
+        compute: bool = True,
+        random_state: int | None = None,
+        solver: str = "auto",
+        solver_kwargs: dict = {},
+        **kwargs,
+    ):
+        super().__init__(
+            n_modes=n_modes,
+            center=center,
+            standardize=standardize,
+            use_coslat=use_coslat,
+            check_nans=check_nans,
+            sample_name=sample_name,
+            feature_name=feature_name,
+            compute=compute,
+            random_state=random_state,
+            solver=solver,
+            solver_kwargs=solver_kwargs,
+            **kwargs,
+        )
+        self.attrs.update({"model": "Principal Oscillation Pattern analysis"})
+
+        self.whitener = Whitener(
+            alpha=1.0,
+            use_pca=use_pca,
+            n_modes=n_pca_modes,
+            init_rank_reduction=pca_init_rank_reduction,
+            sample_name=sample_name,
+            feature_name=feature_name,
+            compute_svd=compute,
+            random_state=random_state,
+            solver_kwargs=solver_kwargs,
+        )
+
+        self.sorted = False
+
+    def get_serialization_attrs(self) -> dict:
+        return dict(
+            data=self.data,
+            preprocessor=self.preprocessor,
+            whitener=self.whitener,
+            sorted=self.sorted,
+        )
+
+    def _np_solve_pop_system(self, X):
+        # Feedback matrix
+        A = X[1:].conj().T @ X[:-1] @ np.linalg.inv(X[:-1].conj().T @ X[:-1])
+
+        # Compute POPs
+        lbda, P = np.linalg.eig(A)
+
+        # e-folding times / damping times
+        tau = -1 / np.log(abs(lbda))
+
+        # POP periods
+        with warnings.catch_warnings(record=True):
+            warnings.filterwarnings(
+                "ignore", "divide by zero encountered", RuntimeWarning
+            )
+
+            T = 2 * np.pi / np.angle(lbda)
+
+        # POP (time) coefficients (Storch et al. 1995, equation 19)
+        Z = self._np_compute_pop_coefficients(X, P)
+
+        # Reconstruction of original data
+        # Xrec = Z @ P.T
+        # Note: the signs of some columns of Xrec may be flipped relative to the
+        # input, probably due to the sign/phase ambiguity of the eigenvectors.
+
+        return P, Z, lbda, T, tau
+
+    def _np_compute_pop_coefficients(self, X, P):
+        # POP (time) coefficients (Storch et al. 1995, equation 19)
+        Z = np.empty((X.shape[0], P.shape[1]), dtype=complex)
+        for i in range(P.shape[1]):
+            p = P[:, i : i + 1]
+            pr = p.real
+            pi = p.imag
+
+            M = np.array([[pr.T @ pr, pr.T @ pi], [pr.T @ pi, pi.T @ pi]]).squeeze()
+            Minv = np.linalg.pinv(M)
+            zri = Minv @ np.hstack([X @ pr, X @ pi]).T
+            z = zri[0] + 1j * zri[1]
+            Z[:, i] = z
+        return Z
+
+    def _fit_algorithm(self, X: DataArray) -> Self:
+        sample_name = self.sample_name
+        feature_name = self.feature_name
+
+        # Transform into PC space
+        X = self.whitener.fit_transform(X)
+
+        P, Z, lbda, T, tau = xr.apply_ufunc(
+            self._np_solve_pop_system,
+            X,
+            input_core_dims=[[sample_name, feature_name]],
+            output_core_dims=[
+                [feature_name, "mode"],
+                [sample_name, "mode"],
+                ["mode"],
+                ["mode"],
+                ["mode"],
+            ],
+            dask="allowed",
+        )
+
+        mode_coords = np.arange(1, P.mode.size + 1)
+        P = P.assign_coords(mode=mode_coords)
+        Z = Z.assign_coords(mode=mode_coords)
+        lbda = lbda.assign_coords(mode=mode_coords)
+        T = T.assign_coords(mode=mode_coords)
+        tau = tau.assign_coords(mode=mode_coords)
+
+        # Compute dynamical importance of each mode
+        var_Z = Z.var(sample_name)
+        norms = (var_Z) ** (0.5)
+
+        # Compute total variance
+        var_tot = total_variance(X, sample_name)
+
+        # Reorder according to variance
+        idx_modes_sorted = argsort_dask(norms, "mode")[::-1]  # type: ignore
+        idx_modes_sorted.coords.update(norms.coords)
+
+        P = self.whitener.inverse_transform_components(P)
+
+        # Store the results
+        self.data.add(X, "input_data", allow_compute=False)
+        self.data.add(P, "components")
+        self.data.add(Z, "scores")
+        self.data.add(norms, "norms")
+        self.data.add(lbda, "eigenvalues")
+        self.data.add(tau, "damping_times")
+        self.data.add(T, "periods")
+        self.data.add(idx_modes_sorted, "idx_modes_sorted")
+        self.data.add(var_tot, "total_variance")
+
+        self.data.set_attrs(self.attrs)
+        return self
+
+    def _post_compute(self):
+        """Leave sorting until after compute because it can't be done lazily."""
+        self._sort_by_variance()
+
+    def _sort_by_variance(self):
+        """Re-sort the mode dimension of all data variables by variance explained."""
+        if not self.sorted:
+            for key in self.data.keys():
+                if "mode" in self.data[key].dims and key != "idx_modes_sorted":
+                    self.data[key] = (
+                        self.data[key]
+                        .isel(mode=self.data["idx_modes_sorted"].values)
+                        .assign_coords(mode=self.data[key].mode)
+                    )
+            self.sorted = True
+
+    def _transform_algorithm(self, X: DataArray) -> DataArray:
+        sample_name = self.sample_name
+        feature_name = self.feature_name
+
+        P = self.data["components"]
+
+        # Transform into PC space
+        P = self.whitener.transform_components(P)
+        X = self.whitener.transform(X)
+
+        # Project the data
+        Z = xr.apply_ufunc(
+            self._np_compute_pop_coefficients,
+            X,
+            P,
+            input_core_dims=[[sample_name, feature_name], [feature_name, "mode"]],
+            output_core_dims=[[sample_name, "mode"]],
+            dask="allowed",
+        )
+        Z.name = "scores"
+
+        Z = self.whitener.inverse_transform_scores(Z)
+
+        return Z
+
+    def _inverse_transform_algorithm(self, scores: DataArray) -> DataArray:
+        """Reconstruct the original data from transformed data.
+
+        Parameters
+        ----------
+        scores: DataArray
+            Transformed data to be reconstructed. This could be a subset
+            of the `scores` data of a fitted model, or unseen data. Must
+            have a 'mode' dimension.
+
+        Returns
+        -------
+        data: DataArray
+            Reconstructed data.
+ + """ + # Reconstruct the data + P = self.data["components"].sel(mode=scores.mode) + + # Transform in PC space + P = self.whitener.transform_components(P) + + reconstructed_data = xr.dot(scores, P, dims="mode") + reconstructed_data.name = "reconstructed_data" + + # Inverse transform the data into physical space + reconstructed_data = self.whitener.inverse_transform_data(reconstructed_data) + + return reconstructed_data + + def components(self) -> DataObject: + """Return the POPs. + + The POPs are the eigenvectors of the feedback matrix. + + Returns + ------- + components: DataObject + Principal Oscillation Patterns (POPs). + + """ + return super().components(normalized=False) + + def scores(self, normalized: bool = False) -> DataArray: + """Return the POP coefficients/scores. + + Parameters + ---------- + normalized : bool, default=True + Whether to normalize the scores by the L2 norm. + + Returns + ------- + components: DataObject + POP coefficients. + + """ + return super().scores(normalized=normalized) + + def eigenvalues(self) -> DataArray: + """Return the eigenvalues of the feedback matrix. + + Returns + ------- + DataArray + Real or complex eigenvalues. + + """ + return self.data["eigenvalues"] + + def damping_times(self) -> DataArray: + """Return the damping times of the feedback matrix. + + The damping times are defined as + + .. math:: + \\tau = -\\frac{1}{\\log(|\\lambda|)} + + where :math:`\\lambda` is the eigenvalue. + + Returns + ------- + DataArray + Damping times. + + """ + return self.data["damping_times"] + + def periods(self) -> DataArray: + """Return the periods of the feedback matrix. + + For complex eigenvalues, the periods are defined as + + .. math:: + T = \\frac{2\\pi}{\\arg(\\lambda)} + + where :math:`\\lambda` is the eigenvalue. For real eigenvalues ``inf`` + is returned. + + Returns + ------- + DataArray + Periods. + + """ + return self.data["periods"] + + def components_amplitude(self) -> DataObject: + """Return the amplitude of the POP components. + + The amplitude of the components are defined as + + .. math:: + A_{ij} = |C_{ij}| + + where :math:`C_{ij}` is the :math:`i`-th entry of the :math:`j`-th component and + :math:`|\\cdot|` denotes the absolute value. + + + Returns + ------- + components_amplitude: DataObject + Amplitude of the components of the fitted model. + + """ + amplitudes = abs(self.data["components"]) + + amplitudes.name = "components_amplitude" + return self.preprocessor.inverse_transform_components(amplitudes) + + def components_phase(self) -> DataObject: + """Return the phase of the POP components. + + The phase of the components are defined as + + .. math:: + \\phi_{ij} = \\arg(C_{ij}) + + where :math:`C_{ij}` is the :math:`i`-th entry of the :math:`j`-th component and + :math:`\\arg(\\cdot)` denotes the argument of a complex number. + + Returns + ------- + components_phase: DataObject + Phase of the components of the fitted model. + + """ + comps = self.data["components"] + comp_phase = xr.apply_ufunc(np.angle, comps, dask="allowed", keep_attrs=True) + comp_phase.name = "components_phase" + return self.preprocessor.inverse_transform_components(comp_phase) + + def scores_amplitude(self, normalized=True) -> DataArray: + """Return the amplitude of the POP coefficients/scores. + + The amplitude of the scores are defined as + + .. math:: + A_{ij} = |S_{ij}| + + where :math:`S_{ij}` is the :math:`i`-th entry of the :math:`j`-th score and + :math:`|\\cdot|` denotes the absolute value. 
+
+        Parameters
+        ----------
+        normalized : bool, default=True
+            Whether to normalize the scores by the mode norms.
+
+        Returns
+        -------
+        scores_amplitude: DataArray
+            Amplitude of the scores of the fitted model.
+
+        """
+        scores = self.data["scores"].copy()
+        if normalized:
+            scores = scores / self.data["norms"]
+
+        amplitudes = abs(scores)
+        amplitudes.name = "scores_amplitude"
+        return self.preprocessor.inverse_transform_scores(amplitudes)
+
+    def scores_phase(self) -> DataArray:
+        """Return the phase of the POP coefficients/scores.
+
+        The phase of the scores is defined as
+
+        .. math::
+            \\phi_{ij} = \\arg(S_{ij})
+
+        where :math:`S_{ij}` is the :math:`i`-th entry of the :math:`j`-th score and
+        :math:`\\arg(\\cdot)` denotes the argument of a complex number.
+
+        Returns
+        -------
+        scores_phase: DataArray
+            Phase of the scores of the fitted model.
+
+        """
+        scores = self.data["scores"]
+        phases = xr.apply_ufunc(np.angle, scores, dask="allowed", keep_attrs=True)
+        phases.name = "scores_phase"
+        return self.preprocessor.inverse_transform_scores(phases)
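For readers unfamiliar with POP analysis, the following standalone NumPy sketch mirrors what `_np_solve_pop_system` and `_np_compute_pop_coefficients` do in the patch: estimate the feedback matrix A = C_1 C_0^{-1} from the lag-0 and lag-1 covariances, take its eigendecomposition to obtain the POPs, derive damping times and periods from the eigenvalues, and fit the POP coefficients via the least-squares projection of von Storch et al. (1995), eq. 19. This is only an illustrative sketch: the synthetic random data and the variable names below are assumptions and are not part of the xeofs API, and it skips the whitening/PCA, mode sorting, and xarray bookkeeping that the POP class performs.

import numpy as np

rng = np.random.default_rng(42)

# Synthetic, centered data: 300 time steps, 5 features (e.g. PC scores).
n_time, n_feat = 300, 5
X = rng.standard_normal((n_time, n_feat))
X -= X.mean(axis=0)

# Feedback matrix A = C_1 C_0^{-1} from lag-1 and lag-0 covariances.
C0 = X[:-1].conj().T @ X[:-1]
C1 = X[1:].conj().T @ X[:-1]
A = C1 @ np.linalg.inv(C0)

# POPs are the (generally complex) eigenvectors of A.
lbda, P = np.linalg.eig(A)

# e-folding damping times (positive for a damped system with |lambda| < 1)
# and oscillation periods; purely real eigenvalues give an infinite period.
tau = -1.0 / np.log(np.abs(lbda))
with np.errstate(divide="ignore"):
    T = 2 * np.pi / np.angle(lbda)

# POP coefficients via the least-squares fit (von Storch et al. 1995, eq. 19):
# for each mode, project the data onto the real and imaginary parts of the POP.
Z = np.empty((n_time, n_feat), dtype=complex)
for i in range(n_feat):
    pr, pi = P[:, i].real, P[:, i].imag
    M = np.array([[pr @ pr, pr @ pi], [pr @ pi, pi @ pi]])
    zr, zi = np.linalg.pinv(M) @ np.vstack([X @ pr, X @ pi])
    Z[:, i] = zr + 1j * zi

print("damping times:", np.round(tau, 2))
print("periods:", np.round(T, 2))

In the patch itself, the same steps run inside `POP._fit_algorithm` on data that has first been whitened (and optionally PCA-reduced) by the `Whitener`, and the resulting modes are re-sorted afterwards by the variance of their coefficients.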